From 524158812179ceace0d0f1621bae76df15435188 Mon Sep 17 00:00:00 2001 From: Hongsheng Liu Date: Tue, 10 Mar 2026 22:32:05 +0800 Subject: [PATCH] =?UTF-8?q?Revert=20"Add=20online=20serving=20to=20Stable?= =?UTF-8?q?=20Audio=20Diffusion=20and=20introduce=20`v1/audio/=E2=80=A6"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 7a560492a2fd4fb17ffd2ded65a214f915567cff. --- docs/serving/audio_generate_api.md | 338 ------------ .../examples/online_serving/text_to_audio.md | 193 ------- .../online_serving/stable_audio/README.md | 234 -------- .../stable_audio/curl_examples.sh | 54 -- .../stable_audio/stable_audio_client.py | 170 ------ .../openai_api/test_serving_audio_generate.py | 509 ------------------ tests/entrypoints/test_omni_diffusion.py | 18 +- tests/entrypoints/test_omni_llm.py | 16 +- .../stable_audio/pipeline_stable_audio.py | 4 +- vllm_omni/entrypoints/omni_llm.py | 6 +- vllm_omni/entrypoints/openai/api_server.py | 78 +-- .../entrypoints/openai/audio_utils_mixin.py | 4 - .../entrypoints/openai/protocol/audio.py | 47 -- .../openai/serving_audio_generate.py | 168 ------ .../entrypoints/openai/serving_speech.py | 7 +- vllm_omni/entrypoints/utils.py | 77 ++- 16 files changed, 63 insertions(+), 1860 deletions(-) delete mode 100644 docs/serving/audio_generate_api.md delete mode 100644 docs/user_guide/examples/online_serving/text_to_audio.md delete mode 100644 examples/online_serving/stable_audio/README.md delete mode 100755 examples/online_serving/stable_audio/curl_examples.sh delete mode 100755 examples/online_serving/stable_audio/stable_audio_client.py delete mode 100644 tests/entrypoints/openai_api/test_serving_audio_generate.py delete mode 100644 vllm_omni/entrypoints/openai/serving_audio_generate.py diff --git a/docs/serving/audio_generate_api.md b/docs/serving/audio_generate_api.md deleted file mode 100644 index e7eaef1860b..00000000000 --- a/docs/serving/audio_generate_api.md +++ /dev/null @@ -1,338 +0,0 @@ -# Audio Generate API - -vLLM-Omni provides an API for text-to-audio generation using diffusion-based models such as Stable Audio. - -Unlike the [Speech API](speech_api.md) which targets text-to-speech synthesis, the Audio Generate API is designed for general-purpose audio generation from text descriptions (sound effects, music, ambient soundscapes, etc.). - -Each server instance runs a single model (specified at startup via `vllm-omni serve --omni`). - -## Quick Start - -### Start the Server - -```bash -vllm-omni serve stabilityai/stable-audio-open-1.0 \ - --host 0.0.0.0 \ - --port 8000 \ - --gpu-memory-utilization 0.9 \ - --trust-remote-code \ - --enforce-eager \ - --omni -``` - -### Generate Audio - -**Using curl:** - -```bash -curl -X POST http://localhost:8000/v1/audio/generate \ - -H "Content-Type: application/json" \ - -d '{ - "input": "The sound of a cat purring", - "audio_length": 10.0 - }' --output cat.wav -``` - -**Using Python:** - -```python -import httpx - -response = httpx.post( - "http://localhost:8000/v1/audio/generate", - json={ - "input": "The sound of a cat purring", - "audio_length": 10.0, - }, - timeout=300.0, -) - -with open("cat.wav", "wb") as f: - f.write(response.content) -``` - -## API Reference - -### Endpoint - -``` -POST /v1/audio/generate -Content-Type: application/json -``` - -### Request Parameters - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `input` | string | **required** | Text prompt describing the audio to generate | -| `model` | string | server's model | Model to use (optional, should match server if specified) | -| `response_format` | string | "wav" | Audio format: wav, mp3, flac, pcm, aac, opus | -| `speed` | float | 1.0 | Playback speed (0.25 - 4.0) | - -#### Diffusion Parameters - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `audio_length` | float | null | Audio duration in seconds (default value is the max ~47s for `stable-audio-open-1.0`) | -| `audio_start` | float | 0.0 | Audio start time in seconds | -| `negative_prompt` | string | null | Text describing what to avoid in generation | -| `guidance_scale` | float | model default | Classifier-free guidance scale (higher = more adherence to prompt) | -| `num_inference_steps` | int | model default | Number of denoising steps (higher = better quality, slower) | -| `seed` | int | null | Random seed for reproducible generation | - -### Response Format - -Returns binary audio data with the appropriate `Content-Type` header: - -| `response_format` | Content-Type | -|--------------------|--------------| -| `wav` | `audio/wav` | -| `mp3` | `audio/mpeg` | -| `flac` | `audio/flac` | -| `pcm` | `audio/pcm` | -| `aac` | `audio/aac` | -| `opus` | `audio/opus` | - -## Examples - -### Basic Generation - -Generate audio with only a text prompt (model defaults for all other parameters): - -```bash -curl -X POST http://localhost:8000/v1/audio/generate \ - -H "Content-Type: application/json" \ - -d '{ - "input": "The sound of ocean waves crashing on a beach" - }' --output ocean.wav -``` - -### Custom Duration - -Specify an explicit audio length in seconds: - -```bash -curl -X POST http://localhost:8000/v1/audio/generate \ - -H "Content-Type: application/json" \ - -d '{ - "input": "A dog barking", - "audio_length": 5.0 - }' --output dog_5s.wav -``` - -### High Quality with Negative Prompt - -Use a negative prompt to steer generation away from undesired characteristics, and increase inference steps for higher quality: - -```bash -curl -X POST http://localhost:8000/v1/audio/generate \ - -H "Content-Type: application/json" \ - -d '{ - "input": "A piano playing a gentle melody", - "audio_length": 10.0, - "negative_prompt": "Low quality, distorted, noisy", - "guidance_scale": 8.0, - "num_inference_steps": 150 - }' --output piano_hq.wav -``` - -### Reproducible Generation - -Set a `seed` to get deterministic results across runs: - -```bash -curl -X POST http://localhost:8000/v1/audio/generate \ - -H "Content-Type: application/json" \ - -d '{ - "input": "Thunder and rain sounds", - "audio_length": 15.0, - "seed": 42 - }' --output thunder.wav -``` - -### Full Control - -Combine all parameters for precise control over generation: - -```bash -curl -X POST http://localhost:8000/v1/audio/generate \ - -H "Content-Type: application/json" \ - -d '{ - "input": "Thunder and rain sounds", - "audio_length": 15.0, - "negative_prompt": "Low quality", - "guidance_scale": 7.0, - "num_inference_steps": 100, - "seed": 42 - }' --output thunder_rain.wav -``` - -### Quick Generation (Fewer Steps) - -For faster generation with slightly lower quality: - -```bash -curl -X POST http://localhost:8000/v1/audio/generate \ - -H "Content-Type: application/json" \ - -d '{ - "input": "Birds chirping in a forest", - "audio_length": 8.0, - "num_inference_steps": 50 - }' --output birds_quick.wav -``` - -### Python Client - -```python -import httpx - -response = httpx.post( - "http://localhost:8000/v1/audio/generate", - json={ - "input": "Thunder and rain", - "audio_length": 15.0, - "negative_prompt": "Low quality", - "guidance_scale": 7.0, - "num_inference_steps": 100, - "seed": 42, - "response_format": "wav", - }, - timeout=300.0, -) - -with open("thunder.wav", "wb") as f: - f.write(response.content) -``` - -## Parameter Tuning Guide - -### `guidance_scale` - -Controls how closely the generated audio follows the text prompt. - -| Range | Behaviour | -|-------|-----------| -| 3 - 5 | More creative / varied output | -| 7 (default) | Balanced adherence | -| 10+ | Strict adherence to the prompt | - -### `num_inference_steps` - -Controls the number of denoising steps in the diffusion process. - -| Steps | Quality | Speed | Use Case | -|-------|---------|-------|----------| -| 50 | Good | Fast | Quick previews | -| 100 | Very Good | Medium | General purpose | -| 150+ | Excellent | Slow | Final / critical audio | - -### `audio_length` - -Duration of the generated audio clip. For `stable-audio-open-1.0`, the maximum is approximately 47 seconds. If omitted, the model uses its own default length. - -### `negative_prompt` - -Describes characteristics to avoid. Common negative prompts include: - -- `"Low quality, distorted, noisy"` -- `"Silence, static"` -- `"Music"` (when generating sound effects only) - -## Supported Models - -| Model | Description | -|-------|-------------| -| `stabilityai/stable-audio-open-1.0` | Open-source audio generation model, up to ~47 seconds, 44.1 kHz stereo | - -## Error Responses - -### 400 Bad Request - -Invalid or missing parameters: - -```json -{ - "error": { - "message": "Audio generation model did not produce audio output.", - "type": "BadRequestError", - "param": null, - "code": 400 - } -} -``` - -### 404 Not Found - -Model mismatch: - -```json -{ - "error": { - "message": "The model `xxx` does not exist.", - "type": "NotFoundError", - "param": "model", - "code": 404 - } -} -``` - -### 422 Unprocessable Entity - -Pydantic validation failure (e.g. invalid `response_format`, `speed` out of range): - -```json -{ - "detail": [ - { - "type": "literal_error", - "msg": "Input should be 'wav', 'pcm', 'flac', 'mp3', 'aac' or 'opus'", - ... - } - ] -} -``` - -## Troubleshooting - -### "Audio generation model did not produce audio output" - -The model finished but returned no audio data. Verify the server started successfully and the model loaded without errors. - -### Server Not Responding - -```bash -# Check if the server is healthy -curl http://localhost:8000/health -``` - -### Audio Quality Issues - -- Increase `num_inference_steps` (e.g. 150). -- Add a negative prompt: `"Low quality, distorted, noisy"`. -- Increase `guidance_scale` for stronger prompt adherence. - -### Generation Timeout - -- Reduce `num_inference_steps`. -- Reduce `audio_length`. -- Check GPU memory with `nvidia-smi`. - -### Out of Memory - -- Lower `--gpu-memory-utilization` (e.g. 0.8). -- Reduce `audio_length`. - -## Development - -Enable debug logging: - -```bash -vllm-omni serve stabilityai/stable-audio-open-1.0 \ - --host 0.0.0.0 \ - --port 8000 \ - --gpu-memory-utilization 0.9 \ - --trust-remote-code \ - --enforce-eager \ - --omni \ - --uvicorn-log-level debug -``` diff --git a/docs/user_guide/examples/online_serving/text_to_audio.md b/docs/user_guide/examples/online_serving/text_to_audio.md deleted file mode 100644 index b84bde6c482..00000000000 --- a/docs/user_guide/examples/online_serving/text_to_audio.md +++ /dev/null @@ -1,193 +0,0 @@ -# Text-To-Audio - -Source . - -This example demonstrates how to deploy Stable Audio models for online text-to-audio generation using vLLM-Omni. - -## Supported Models - -| Model | Description | -|-------|-------------| -| `stabilityai/stable-audio-open-1.0` | Open-source audio generation, up to ~47 seconds, 44.1 kHz stereo | - -## Start Server - -### Basic Start - -```bash -vllm-omni serve stabilityai/stable-audio-open-1.0 \ - --host 0.0.0.0 \ - --port 8000 \ - --gpu-memory-utilization 0.9 \ - --trust-remote-code \ - --enforce-eager \ - --omni -``` - -## API Calls - -### Method 1: Using curl - -```bash -# Run all curl examples -bash curl_examples.sh - -# Or execute directly -curl -X POST http://localhost:8000/v1/audio/generate \ - -H "Content-Type: application/json" \ - -d '{ - "input": "The sound of a cat purring", - "audio_length": 10.0 - }' --output cat.wav -``` - -### Method 2: Using Python Client - -```bash -cd examples/online_serving/stable_audio - -# Simple generation -python stable_audio_client.py \ - --text "The sound of a cat purring" - -# With custom duration -python stable_audio_client.py \ - --text "A dog barking" \ - --audio_length 5.0 - -# With all parameters -python stable_audio_client.py \ - --text "Thunder and rain" \ - --audio_length 15.0 \ - --negative_prompt "Low quality" \ - --guidance_scale 7.0 \ - --num_inference_steps 100 \ - --seed 42 \ - --output thunder.wav -``` - -The Python client supports the following command-line arguments: - -- `--api_url`: API endpoint URL (default: `http://localhost:8000/v1/audio/generate`) -- `--text`: Text prompt for audio generation (default: `"The sound of a cat purring"`) -- `--audio_length`: Audio length in seconds (default: `10.0`, max ~47s for `stable-audio-open-1.0`) -- `--audio_start`: Audio start time in seconds (default: `0.0`) -- `--negative_prompt`: Negative prompt for classifier-free guidance (default: `"Low quality"`) -- `--guidance_scale`: Guidance scale for diffusion (default: `7.0`) -- `--num_inference_steps`: Number of inference steps (default: `100`) -- `--seed`: Random seed for reproducibility (default: `None`) -- `--response_format`: Audio output format (default: `wav`). Options: `wav`, `mp3`, `flac`, `pcm` -- `--output`: Output file path (default: `stable_audio_output.wav`) - -### Method 3: Using Python httpx - -```python -import httpx - -response = httpx.post( - "http://localhost:8000/v1/audio/generate", - json={ - "input": "The sound of ocean waves crashing on a beach", - "audio_length": 10.0, - "negative_prompt": "Low quality, distorted", - "guidance_scale": 7.0, - "num_inference_steps": 100, - }, - timeout=300.0, -) - -with open("ocean.wav", "wb") as f: - f.write(response.content) -``` - -## Request Format - -### Simple Generation - -```json -{ - "input": "The sound of ocean waves" -} -``` - -### Generation with Parameters - -```json -{ - "input": "A piano playing a gentle melody", - "audio_length": 10.0, - "negative_prompt": "Low quality, distorted, noisy", - "guidance_scale": 8.0, - "num_inference_steps": 150, - "seed": 42, - "response_format": "wav" -} -``` - -## API Reference - -### Endpoint - -``` -POST /v1/audio/generate -Content-Type: application/json -``` - -### Generation Parameters - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `input` | string | **required** | Text prompt describing the audio to generate | -| `model` | string | server's model | Model to use (optional, should match server if specified) | -| `response_format` | string | "wav" | Audio format: wav, mp3, flac, pcm, aac, opus | -| `speed` | float | 1.0 | Playback speed (0.25 - 4.0) | -| `audio_length` | float | null | Audio duration in seconds (max ~47s for `stable-audio-open-1.0`) | -| `audio_start` | float | 0.0 | Audio start time in seconds | -| `negative_prompt` | string | null | Text describing what to avoid in generation | -| `guidance_scale` | float | model default | Classifier-free guidance scale (higher = more adherence to prompt) | -| `num_inference_steps` | int | model default | Number of denoising steps (higher = better quality, slower) | -| `seed` | int | null | Random seed for reproducible generation | - -### Response Format - -Returns binary audio data with appropriate `Content-Type` header (e.g., `audio/wav`). - -## Tuning Tips - -1. **Audio Length**: Keep under 47 seconds for `stable-audio-open-1.0`. -2. **Quality vs Speed**: - - 50 steps: Fast, decent quality (quick previews) - - 100 steps: Good balance (general purpose) - - 150+ steps: High quality, slower (final / critical audio) -3. **Guidance Scale**: - - Lower (3 - 5): More creative / varied output - - Default (7): Good balance - - Higher (10+): Strict adherence to the prompt -4. **Negative Prompts**: Use to avoid unwanted characteristics such as `"Low quality"`, `"distorted"`, `"noisy"`. -5. **Seeds**: Set a fixed seed to get deterministic, reproducible results. - -## File Description - -| File | Description | -|------|-------------| -| `curl_examples.sh` | Curl examples covering common use cases | -| `stable_audio_client.py` | Python client with full CLI argument support | - -## Troubleshooting - -1. **Audio generation model did not produce audio output**: Verify the server started successfully and the model loaded without errors. -2. **Connection refused**: Make sure the server is running on the correct port. -3. **Generation timeout**: Reduce `num_inference_steps` or `audio_length`, and check GPU memory with `nvidia-smi`. -4. **Out of memory**: Lower `--gpu-memory-utilization` or reduce `audio_length`. -5. **Audio quality issues**: Increase `num_inference_steps`, add a negative prompt, or raise `guidance_scale`. - -## Example materials - -??? abstract "stable_audio_client.py" - ``````py - --8<-- "examples/online_serving/stable_audio/stable_audio_client.py" - `````` -??? abstract "curl_examples.sh" - ``````sh - --8<-- "examples/online_serving/stable_audio/curl_examples.sh" - `````` diff --git a/examples/online_serving/stable_audio/README.md b/examples/online_serving/stable_audio/README.md deleted file mode 100644 index c8ffaadcaae..00000000000 --- a/examples/online_serving/stable_audio/README.md +++ /dev/null @@ -1,234 +0,0 @@ -# Stable Audio Online Serving - -Generate audio from text prompts using Stable Audio models via an OpenAI-compatible API endpoint. - -## Features - -- **OpenAI-compatible API**: Use `/v1/audio/generate` endpoint -- **Flexible control**: Adjust audio length, guidance scale, inference steps -- **Quality control**: Use negative prompts to avoid unwanted characteristics -- **Reproducible**: Set random seed for deterministic generation - -## Quick Start - -### 1. Start the Server - -```bash -vllm-omni serve stabilityai/stable-audio-open-1.0 \ - --host 0.0.0.0 \ - --port 8000 \ - --gpu-memory-utilization 0.9 \ - --trust-remote-code \ - --enforce-eager \ - --omni -``` - -### 2. Generate Audio - -#### Using curl - -```bash -curl -X POST http://localhost:8000/v1/audio/generate \ - -H "Content-Type: application/json" \ - -d '{ - "input": "The sound of a cat purring", - "audio_length": 10.0 - }' --output cat.wav -``` - -#### Using Python Client - -```bash -python stable_audio_client.py \ - --text "The sound of a cat purring" \ - --audio_length 10.0 \ - --output cat.wav -``` - -#### Using Bash Script - -```bash -bash curl_examples.sh -``` - -## API Reference - -### Endpoint - -``` -POST /v1/audio/generate -``` - -### Request Body - -```json -{ - "input": "Text description of the audio", - "audio_length": 10.0, - "audio_start": 0.0, - "negative_prompt": "Low quality", - "guidance_scale": 7.0, - "num_inference_steps": 100, - "seed": 42, - "response_format": "wav" -} -``` - -### Parameters - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `input` | string | **required** | Text prompt describing the audio to generate | -| `audio_length` | float | ~47s | Audio duration in seconds (max ~47s for stable-audio-open-1.0) | -| `audio_start` | float | 0.0 | Audio start time in seconds | -| `negative_prompt` | string | null | Text describing what to avoid in generation | -| `guidance_scale` | float | 7.0 | Classifier-free guidance scale (higher = more adherence to prompt) | -| `num_inference_steps` | int | 50 | Number of denoising steps (higher = better quality, slower) | -| `seed` | int | null | Random seed for reproducibility | -| `response_format` | string | "wav" | Output format: wav, mp3, flac, pcm | - -### Response - -Returns audio data in the requested format (default: WAV). - -## Usage Examples - -### Basic Generation - -```bash -curl -X POST http://localhost:8000/v1/audio/generate \ - -H "Content-Type: application/json" \ - -d '{ - "input": "The sound of ocean waves" - }' --output ocean.wav -``` - -### Custom Duration - -```bash -curl -X POST http://localhost:8000/v1/audio/generate \ - -H "Content-Type: application/json" \ - -d '{ - "input": "A dog barking", - "audio_length": 5.0 - }' --output dog_5s.wav -``` - -### High Quality with Negative Prompt - -```bash -curl -X POST http://localhost:8000/v1/audio/generate \ - -H "Content-Type: application/json" \ - -d '{ - "input": "A piano playing a gentle melody", - "audio_length": 10.0, - "negative_prompt": "Low quality, distorted, noisy", - "guidance_scale": 8.0, - "num_inference_steps": 150 - }' --output piano_hq.wav -``` - -### Reproducible Generation - -```bash -curl -X POST http://localhost:8000/v1/audio/generate \ - -H "Content-Type: application/json" \ - -d '{ - "input": "Thunder and rain sounds", - "audio_length": 15.0, - "seed": 42 - }' --output thunder.wav -``` - -### Quick Generation (Fewer Steps) - -For faster generation with slightly lower quality: - -```bash -curl -X POST http://localhost:8000/v1/audio/generate \ - -H "Content-Type: application/json" \ - -d '{ - "input": "Birds chirping in a forest", - "audio_length": 8.0, - "num_inference_steps": 50 - }' --output birds_quick.wav -``` - -## Python Client Examples - -### Simple Generation - -```bash -python stable_audio_client.py \ - --text "The sound of a cat purring" -``` - -### Custom Parameters - -```bash -python stable_audio_client.py \ - --text "Thunder and rain" \ - --audio_length 15.0 \ - --negative_prompt "Low quality" \ - --guidance_scale 7.0 \ - --num_inference_steps 100 \ - --seed 42 \ - --output thunder.wav -``` - -### Different Output Format - -```bash -python stable_audio_client.py \ - --text "Guitar playing" \ - --response_format mp3 \ - --output guitar.mp3 -``` - -## Tips - -1. **Audio Length**: Keep under 47 seconds for `stable-audio-open-1.0` -2. **Quality vs Speed**: - - 50 steps: Fast, decent quality - - 100 steps: Good balance (default) - - 150+ steps: High quality, slower -3. **Guidance Scale**: - - Lower (3-5): More creative/varied - - Default (7): Good balance - - Higher (10+): More literal to prompt -4. **Negative Prompts**: Use to avoid "Low quality", "distorted", "noisy", etc. -5. **Seeds**: Use same seed for reproducible results - -## Performance - -| Inference Steps | Quality | Speed | Use Case | -|----------------|---------|-------|----------| -| 50 | Good | Fast | Quick previews | -| 100 (default) | Very Good | Medium | Production | -| 150+ | Excellent | Slow | Final/critical audio | - -## Troubleshooting - -### Server not responding -- Check if server is running: `curl http://localhost:8000/health` -- Check server logs for errors - -### Audio quality issues -- Increase `num_inference_steps` (e.g., 150) -- Add negative prompts: `"Low quality, distorted, noisy"` -- Increase `guidance_scale` for more prompt adherence - -### Generation timeout -- Reduce `num_inference_steps` -- Reduce `audio_length` -- Check GPU memory with `nvidia-smi` - -### Wrong audio length -- Ensure `audio_length` is within model limits (~47s max) -- Adjust `audio_start` if trimming is needed - -## See Also - -- [Offline Inference Example](../../offline_inference/text_to_audio/README.md) -- [Stable Audio Model Card](https://huggingface.co/stabilityai/stable-audio-open-1.0) -- [vLLM-Omni Documentation](https://github.com/vllm-project/vllm-omni) diff --git a/examples/online_serving/stable_audio/curl_examples.sh b/examples/online_serving/stable_audio/curl_examples.sh deleted file mode 100755 index 8c4b0c8463f..00000000000 --- a/examples/online_serving/stable_audio/curl_examples.sh +++ /dev/null @@ -1,54 +0,0 @@ -#!/bin/bash -# Examples for using Stable Audio with curl via /v1/audio/generate endpoint - -# Example 1: Simple request with default parameters -echo "Example 1: Simple request with default parameters" -curl -X POST http://localhost:8000/v1/audio/generate \ - -H "Content-Type: application/json" \ - -d '{ - "input": "The sound audience clapping and cheering in a stadium" - }' --output stadium.wav - -# Example 2: Request with custom audio_length -echo "Example 2: Custom audio length (5 seconds)" -curl -X POST http://localhost:8000/v1/audio/generate \ - -H "Content-Type: application/json" \ - -d '{ - "input": "The sound of a dog barking", - "audio_length": 5.0 - }' --output dog_5s.wav - -# Example 3: Request with negative prompt for quality control -echo "Example 3: With negative prompt" -curl -X POST http://localhost:8000/v1/audio/generate \ - -H "Content-Type: application/json" \ - -d '{ - "input": "A piano playing a gentle melody", - "audio_length": 10.0, - "negative_prompt": "Low quality, distorted, noisy" - }' --output piano.wav - -# Example 4: Full control with all parameters -echo "Example 4: Full control (custom length, guidance, steps, seed)" -curl -X POST http://localhost:8000/v1/audio/generate \ - -H "Content-Type: application/json" \ - -d '{ - "input": "Thunder and rain sounds", - "audio_length": 15.0, - "negative_prompt": "Low quality", - "guidance_scale": 7.0, - "num_inference_steps": 100, - "seed": 42 - }' --output thunder_rain.wav - -# Example 5: Quick generation with fewer steps (faster but lower quality) -echo "Example 5: Quick generation (fewer steps)" -curl -X POST http://localhost:8000/v1/audio/generate \ - -H "Content-Type: application/json" \ - -d '{ - "input": "Ocean waves crashing on a beach", - "audio_length": 8.0, - "num_inference_steps": 50 - }' --output ocean.wav - -echo "All examples completed!" diff --git a/examples/online_serving/stable_audio/stable_audio_client.py b/examples/online_serving/stable_audio/stable_audio_client.py deleted file mode 100755 index 8e6a59d3647..00000000000 --- a/examples/online_serving/stable_audio/stable_audio_client.py +++ /dev/null @@ -1,170 +0,0 @@ -#!/usr/bin/env python3 -""" -OpenAI-compatible client for Stable Audio via /v1/audio/generate endpoint. - -This script demonstrates how to use the OpenAI-compatible speech API -to generate audio from text using Stable Audio models. - -Examples: - # Simple generation - python stable_audio_client.py --text "The sound of a cat purring" - - # With custom duration - python stable_audio_client.py --text "A dog barking" --audio_length 5.0 - - # With all parameters - python stable_audio_client.py --text "Thunder and rain" \ - --audio_length 15.0 \ - --negative_prompt "Low quality" \ - --guidance_scale 7.0 \ - --num_inference_steps 100 \ - --seed 42 \ - --output thunder.wav -""" - -import argparse -import sys - -import requests - - -def parse_args(): - parser = argparse.ArgumentParser(description="Generate audio with Stable Audio via OpenAI-compatible API") - parser.add_argument( - "--api_url", - default="http://localhost:8000/v1/audio/generate", - help="API endpoint URL", - ) - parser.add_argument( - "--text", - default="The sound of a cat purring", - help="Text prompt for audio generation", - ) - parser.add_argument( - "--audio_length", - type=float, - default=10.0, - help="Audio length in seconds (max ~47s for stable-audio-open-1.0)", - ) - parser.add_argument( - "--audio_start", - type=float, - default=0.0, - help="Audio start time in seconds", - ) - parser.add_argument( - "--negative_prompt", - default="Low quality", - help="Negative prompt for classifier-free guidance", - ) - parser.add_argument( - "--guidance_scale", - type=float, - default=7.0, - help="Guidance scale for diffusion (higher = more adherence to prompt)", - ) - parser.add_argument( - "--num_inference_steps", - type=int, - default=100, - help="Number of inference steps (higher = better quality, slower)", - ) - parser.add_argument( - "--seed", - type=int, - default=None, - help="Random seed for reproducibility", - ) - parser.add_argument( - "--output", - default="stable_audio_output.wav", - help="Output file path", - ) - parser.add_argument( - "--response_format", - default="wav", - choices=["wav", "mp3", "flac", "pcm"], - help="Audio output format", - ) - return parser.parse_args() - - -def generate_audio(args): - """Generate audio using the API.""" - - # Build request payload - payload = { - "input": args.text, - "audio_length": args.audio_length, - "audio_start": args.audio_start, - "response_format": args.response_format, - } - - # Add optional parameters - if args.negative_prompt: - payload["negative_prompt"] = args.negative_prompt - if args.guidance_scale: - payload["guidance_scale"] = args.guidance_scale - if args.num_inference_steps: - payload["num_inference_steps"] = args.num_inference_steps - if args.seed is not None: - payload["seed"] = args.seed - - print(f"\n{'=' * 60}") - print("Stable Audio - Text-to-Audio Generation") - print(f"{'=' * 60}") - print(f"API URL: {args.api_url}") - print(f"Prompt: {args.text}") - print(f"Audio length: {args.audio_length}s") - print(f"Negative prompt: {args.negative_prompt}") - print(f"Guidance scale: {args.guidance_scale}") - print(f"Inference steps: {args.num_inference_steps}") - if args.seed is not None: - print(f"Seed: {args.seed}") - print(f"Output: {args.output}") - print(f"{'=' * 60}\n") - - try: - # Make the API request - print("Generating audio...") - response = requests.post( - args.api_url, - json=payload, - headers={"Content-Type": "application/json"}, - timeout=300, # 5 minute timeout for long generations - ) - - # Check for errors - if response.status_code != 200: - print(f"Error: API returned status code {response.status_code}") - print(f"Response: {response.text}") - return False - - # Save the audio - with open(args.output, "wb") as f: - f.write(response.content) - - print(f"✓ Audio saved to {args.output}") - print(f" File size: {len(response.content) / 1024:.1f} KB") - return True - - except requests.exceptions.Timeout: - print("Error: Request timed out. Try reducing inference steps or audio length.") - return False - except requests.exceptions.ConnectionError: - print(f"Error: Could not connect to {args.api_url}") - print("Make sure the server is running.") - return False - except Exception as e: - print(f"Error: {e}") - return False - - -def main(): - args = parse_args() - success = generate_audio(args) - sys.exit(0 if success else 1) - - -if __name__ == "__main__": - main() diff --git a/tests/entrypoints/openai_api/test_serving_audio_generate.py b/tests/entrypoints/openai_api/test_serving_audio_generate.py deleted file mode 100644 index fcf22bcefc6..00000000000 --- a/tests/entrypoints/openai_api/test_serving_audio_generate.py +++ /dev/null @@ -1,509 +0,0 @@ -# tests/entrypoints/openai_api/test_serving_audio_generate.py -import logging -from inspect import Signature, signature -from unittest.mock import MagicMock, patch - -import pytest -import torch -from fastapi import FastAPI -from fastapi.testclient import TestClient - -from vllm_omni.entrypoints.openai.protocol.audio import ( - CreateAudio, - OpenAICreateAudioGenerateRequest, -) -from vllm_omni.entrypoints.openai.serving_audio_generate import ( - OmniOpenAIServingAudioGenerate, -) -from vllm_omni.inputs.data import OmniDiffusionSamplingParams -from vllm_omni.outputs import OmniRequestOutput - -pytestmark = [pytest.mark.core_model, pytest.mark.cpu] - -logger = logging.getLogger(__name__) - - -# Helper: create a mock audio output for endpoint tests -def create_mock_audio_output( - request_id: str = "audiogen-mock-123", - sample_rate: int = 44100, - num_samples: int = 44100, - audio_key: str = "audio", -) -> OmniRequestOutput: - """Return an OmniRequestOutput mimicking diffusion audio model output.""" - - audio_tensor = torch.sin(torch.linspace(0, 440 * 2 * torch.pi, num_samples)) - - return OmniRequestOutput.from_diffusion( - request_id=request_id, - images=[], - prompt=None, - metrics={}, - multimodal_output={audio_key: audio_tensor, "sr": sample_rate}, - ) - - -def _make_engine_client(*, audio_key: str = "audio", sample_rate: int = 44100): - """Build a mock engine client producing audio output.""" - mock_engine_client = MagicMock() - mock_engine_client.errored = False - mock_engine_client.model_type = "StableAudioPipeline" - mock_engine_client.default_sampling_params_list = [{}] - - async def mock_generate_fn(*args, **kwargs): - yield create_mock_audio_output( - request_id=kwargs.get("request_id", "audiogen-mock"), - sample_rate=sample_rate, - audio_key=audio_key, - ) - - mock_engine_client.generate = MagicMock(side_effect=mock_generate_fn) - return mock_engine_client - - -def _make_server(engine_client=None): - """Build an OmniOpenAIServingAudioGenerate with mocks.""" - if engine_client is None: - engine_client = _make_engine_client() - - mock_models = MagicMock() - mock_models.is_base_model.return_value = True - - return OmniOpenAIServingAudioGenerate( - engine_client=engine_client, - models=mock_models, - request_logger=MagicMock(), - ) - - -@pytest.fixture -def test_app(): - server = _make_server() - - original_fn = server.create_audio_generate - sig = signature(original_fn) - new_params = [p for name, p in sig.parameters.items() if name != "raw_request"] - new_sig = Signature(parameters=new_params, return_annotation=sig.return_annotation) - - async def patched_create_audio_generate(*args, **kwargs): - return await original_fn(*args, **kwargs) - - patched_create_audio_generate.__signature__ = new_sig - server.create_audio_generate = patched_create_audio_generate - - app = FastAPI() - app.add_api_route( - "/v1/audio/generate", - server.create_audio_generate, - methods=["POST"], - response_model=None, - ) - return app - - -@pytest.fixture -def client(test_app): - return TestClient(test_app) - - -# Request Validation (Pydantic model) -class TestRequestValidation: - """Validate OpenAICreateAudioGenerateRequest pydantic constraints.""" - - def test_valid_minimal_request(self): - req = OpenAICreateAudioGenerateRequest(input="A calm piano melody") - assert req.input == "A calm piano melody" - assert req.response_format == "wav" - assert req.speed == 1.0 - - def test_fields_are_wired_correctly(self): - req = OpenAICreateAudioGenerateRequest( - input="rain sounds", - model="stable-audio", - response_format="flac", - speed=1.5, - audio_length=10.0, - audio_start=2.0, - negative_prompt="noise", - guidance_scale=7.5, - num_inference_steps=100, - seed=42, - ) - assert req.input == "rain sounds" - assert req.model == "stable-audio" - assert req.response_format == "flac" - assert req.speed == 1.5 - assert req.audio_length == 10.0 - assert req.audio_start == 2.0 - assert req.negative_prompt == "noise" - assert req.guidance_scale == 7.5 - assert req.num_inference_steps == 100 - assert req.seed == 42 - - def test_invalid_response_format(self): - with pytest.raises(Exception): - OpenAICreateAudioGenerateRequest(input="test", response_format="invalid_format") - - def test_speed_lower_bound(self): - with pytest.raises(Exception): - OpenAICreateAudioGenerateRequest(input="test", speed=0.1) - - def test_speed_upper_bound(self): - with pytest.raises(Exception): - OpenAICreateAudioGenerateRequest(input="test", speed=5.0) - - def test_speed_at_boundaries(self): - req_low = OpenAICreateAudioGenerateRequest(input="test", speed=0.25) - assert req_low.speed == 0.25 - req_high = OpenAICreateAudioGenerateRequest(input="test", speed=4.0) - assert req_high.speed == 4.0 - - def test_stream_format_sse_rejected(self): - with pytest.raises(Exception): - OpenAICreateAudioGenerateRequest(input="test", stream_format="sse") - - def test_stream_format_audio_accepted(self): - req = OpenAICreateAudioGenerateRequest(input="test", stream_format="audio") - assert req.stream_format == "audio" - - -# Constructor & Class Methods -class TestConstructor: - def test_default_init(self): - server = _make_server() - assert server.diffusion_mode is False - - def test_for_diffusion_factory(self): - engine_client = _make_engine_client() - mock_models = MagicMock() - mock_models.is_base_model.return_value = True - - server = OmniOpenAIServingAudioGenerate.for_diffusion( - engine_client=engine_client, - models=mock_models, - request_logger=MagicMock(), - ) - assert server.diffusion_mode is True - - def test_is_stable_audio_model_true(self): - server = _make_server() - assert server._is_stable_audio_model() is True - - def test_is_stable_audio_model_false(self): - engine = _make_engine_client() - engine.model_type = "SomeOtherModel" - server = _make_server(engine_client=engine) - assert server._is_stable_audio_model() is False - - -# Parameter Wiring — verify request params reach the engine -class TestParameterWiring: - """Ensure request parameters are correctly forwarded to the engine.""" - - @pytest.fixture - def server_and_engine(self): - engine = _make_engine_client() - server = _make_server(engine_client=engine) - return server, engine - - @pytest.mark.asyncio - async def test_prompt_wiring(self, server_and_engine): - server, engine = server_and_engine - req = OpenAICreateAudioGenerateRequest(input="birds chirping") - await server.create_audio_generate(req) - - engine.generate.assert_called_once() - call_kwargs = engine.generate.call_args[1] - assert call_kwargs["prompt"]["prompt"] == "birds chirping" - assert call_kwargs["output_modalities"] == ["audio"] - - @pytest.mark.asyncio - async def test_negative_prompt_wiring(self, server_and_engine): - server, engine = server_and_engine - req = OpenAICreateAudioGenerateRequest(input="a calm ocean", negative_prompt="noise distortion") - await server.create_audio_generate(req) - - call_kwargs = engine.generate.call_args[1] - assert call_kwargs["prompt"]["negative_prompt"] == "noise distortion" - - @pytest.mark.asyncio - async def test_negative_prompt_absent(self, server_and_engine): - server, engine = server_and_engine - req = OpenAICreateAudioGenerateRequest(input="a calm ocean") - await server.create_audio_generate(req) - - call_kwargs = engine.generate.call_args[1] - assert "negative_prompt" not in call_kwargs["prompt"] - - @pytest.mark.asyncio - async def test_guidance_scale_wiring(self, server_and_engine): - server, engine = server_and_engine - req = OpenAICreateAudioGenerateRequest(input="test", guidance_scale=12.0) - await server.create_audio_generate(req) - - call_kwargs = engine.generate.call_args[1] - sp = call_kwargs["sampling_params_list"][0] - assert isinstance(sp, OmniDiffusionSamplingParams) - assert sp.guidance_scale == 12.0 - - @pytest.mark.asyncio - async def test_num_inference_steps_wiring(self, server_and_engine): - server, engine = server_and_engine - req = OpenAICreateAudioGenerateRequest(input="test", num_inference_steps=200) - await server.create_audio_generate(req) - - sp = engine.generate.call_args[1]["sampling_params_list"][0] - assert sp.num_inference_steps == 200 - - @pytest.mark.asyncio - async def test_seed_creates_generator(self, server_and_engine): - server, engine = server_and_engine - req = OpenAICreateAudioGenerateRequest(input="test", seed=42) - - with patch("vllm_omni.entrypoints.openai.serving_audio_generate.torch") as mock_torch: - mock_gen = MagicMock() - mock_gen.manual_seed.return_value = mock_gen - mock_torch.Generator.return_value = mock_gen - - await server.create_audio_generate(req) - - mock_torch.Generator.assert_called_once() - mock_gen.manual_seed.assert_called_once_with(42) - - @pytest.mark.asyncio - async def test_seed_none_skips_generator(self, server_and_engine): - server, engine = server_and_engine - req = OpenAICreateAudioGenerateRequest(input="test") - - await server.create_audio_generate(req) - - sp = engine.generate.call_args[1]["sampling_params_list"][0] - assert sp.generator is None - - @pytest.mark.asyncio - async def test_audio_length_wiring(self, server_and_engine): - server, engine = server_and_engine - req = OpenAICreateAudioGenerateRequest(input="test", audio_length=10.0, audio_start=2.0) - await server.create_audio_generate(req) - - sp = engine.generate.call_args[1]["sampling_params_list"][0] - assert sp.extra_args["audio_start_in_s"] == 2.0 - assert sp.extra_args["audio_end_in_s"] == 12.0 # start + length - - @pytest.mark.asyncio - async def test_audio_length_default_start(self, server_and_engine): - server, engine = server_and_engine - req = OpenAICreateAudioGenerateRequest(input="test", audio_length=5.0) - await server.create_audio_generate(req) - - sp = engine.generate.call_args[1]["sampling_params_list"][0] - assert sp.extra_args["audio_start_in_s"] == 0.0 - assert sp.extra_args["audio_end_in_s"] == 5.0 - - @pytest.mark.asyncio - async def test_no_audio_length_skips_extra_args(self, server_and_engine): - server, engine = server_and_engine - req = OpenAICreateAudioGenerateRequest(input="test") - await server.create_audio_generate(req) - - sp = engine.generate.call_args[1]["sampling_params_list"][0] - assert sp.extra_args == {} - - @pytest.mark.asyncio - async def test_defaults_not_set_when_omitted(self, server_and_engine): - """Guidance scale and num_inference_steps keep dataclass defaults when not in request.""" - server, engine = server_and_engine - req = OpenAICreateAudioGenerateRequest(input="test") - await server.create_audio_generate(req) - - sp = engine.generate.call_args[1]["sampling_params_list"][0] - defaults = OmniDiffusionSamplingParams() - assert sp.guidance_scale == defaults.guidance_scale - assert sp.num_inference_steps == defaults.num_inference_steps - - -# Audio Response Format -class TestAudioResponseFormat: - def test_wav_response(self, client): - payload = {"input": "a gentle rain", "response_format": "wav"} - response = client.post("/v1/audio/generate", json=payload) - assert response.status_code == 200 - assert response.headers["content-type"] == "audio/wav" - assert len(response.content) > 0 - - def test_mp3_response(self, client): - payload = {"input": "a gentle rain", "response_format": "mp3"} - response = client.post("/v1/audio/generate", json=payload) - assert response.status_code == 200 - assert response.headers["content-type"] == "audio/mpeg" - assert len(response.content) > 0 - - def test_flac_response(self, client): - payload = {"input": "a gentle rain", "response_format": "flac"} - response = client.post("/v1/audio/generate", json=payload) - assert response.status_code == 200 - assert response.headers["content-type"] == "audio/flac" - assert len(response.content) > 0 - - def test_invalid_format_rejected(self, client): - payload = {"input": "test", "response_format": "banana"} - response = client.post("/v1/audio/generate", json=payload) - assert response.status_code == 422 - - @patch("vllm_omni.entrypoints.openai.serving_audio_generate.OmniOpenAIServingAudioGenerate.create_audio") - def test_speed_parameter_forwarded(self, mock_create_audio, test_app): - mock_audio_response = MagicMock() - mock_audio_response.audio_data = b"dummy_audio" - mock_audio_response.media_type = "audio/wav" - mock_create_audio.return_value = mock_audio_response - - c = TestClient(test_app) - payload = {"input": "test", "response_format": "wav", "speed": 2.5} - c.post("/v1/audio/generate", json=payload) - - mock_create_audio.assert_called_once() - audio_obj = mock_create_audio.call_args[0][0] - assert isinstance(audio_obj, CreateAudio) - assert audio_obj.speed == 2.5 - - @patch("vllm_omni.entrypoints.openai.serving_audio_generate.OmniOpenAIServingAudioGenerate.create_audio") - def test_sample_rate_from_output(self, mock_create_audio, test_app): - mock_audio_response = MagicMock() - mock_audio_response.audio_data = b"dummy" - mock_audio_response.media_type = "audio/wav" - mock_create_audio.return_value = mock_audio_response - - c = TestClient(test_app) - payload = {"input": "test"} - c.post("/v1/audio/generate", json=payload) - - audio_obj = mock_create_audio.call_args[0][0] - assert audio_obj.sample_rate == 44100 # Stable Audio default - - -# Error Handling -class TestErrorHandling: - @pytest.mark.asyncio - async def test_no_output_returns_error(self): - engine = _make_engine_client() - - async def empty_gen(*args, **kwargs): - return - yield # noqa: unreachable – makes this an async generator - - engine.generate = MagicMock(side_effect=empty_gen) - server = _make_server(engine_client=engine) - req = OpenAICreateAudioGenerateRequest(input="test") - resp = await server.create_audio_generate(req) - - # create_error_response returns an ErrorResponse with .error.message - assert "No output generated" in resp.error.message - - @pytest.mark.asyncio - async def test_no_audio_in_output_returns_error(self): - engine = _make_engine_client() - - async def gen_without_audio(*args, **kwargs): - yield OmniRequestOutput.from_diffusion( - request_id="test", - images=[], - prompt=None, - metrics={}, - multimodal_output={}, # no audio key - ) - - engine.generate = MagicMock(side_effect=gen_without_audio) - server = _make_server(engine_client=engine) - req = OpenAICreateAudioGenerateRequest(input="test") - resp = await server.create_audio_generate(req) - - assert "did not produce audio" in resp.error.message - - @pytest.mark.asyncio - async def test_engine_errored_raises(self): - engine = _make_engine_client() - engine.errored = True - engine.dead_error = RuntimeError("engine is dead") - server = _make_server(engine_client=engine) - - req = OpenAICreateAudioGenerateRequest(input="test") - with pytest.raises(RuntimeError, match="engine is dead"): - await server.create_audio_generate(req) - - @pytest.mark.asyncio - async def test_model_outputs_key_fallback(self): - """Audio data under 'model_outputs' key should be accepted.""" - engine = _make_engine_client(audio_key="model_outputs") - server = _make_server(engine_client=engine) - req = OpenAICreateAudioGenerateRequest(input="test") - resp = await server.create_audio_generate(req) - - # Should succeed and return a Response with audio bytes - assert hasattr(resp, "body") - assert len(resp.body) > 0 - - @pytest.mark.asyncio - async def test_value_error_returns_error_response(self): - engine = _make_engine_client() - - async def gen_value_error(*args, **kwargs): - raise ValueError("bad value") - yield # noqa: unreachable - - engine.generate = MagicMock(side_effect=gen_value_error) - server = _make_server(engine_client=engine) - req = OpenAICreateAudioGenerateRequest(input="test") - resp = await server.create_audio_generate(req) - - assert "bad value" in resp.error.message - - @pytest.mark.asyncio - async def test_generic_exception_returns_error_response(self): - engine = _make_engine_client() - - async def gen_runtime_error(*args, **kwargs): - raise RuntimeError("something went wrong") - yield # noqa: unreachable - - engine.generate = MagicMock(side_effect=gen_runtime_error) - server = _make_server(engine_client=engine) - req = OpenAICreateAudioGenerateRequest(input="test") - resp = await server.create_audio_generate(req) - - assert "Audio generation failed" in resp.error.message - - -# End-to-End via TestClient -class TestAudioGenerateAPI: - def test_basic_success(self, client): - payload = {"input": "ambient forest sounds"} - response = client.post("/v1/audio/generate", json=payload) - assert response.status_code == 200 - assert len(response.content) > 0 - - def test_with_all_params(self, client): - payload = { - "input": "gentle piano", - "response_format": "wav", - "speed": 1.0, - "audio_length": 5.0, - "audio_start": 0.0, - "negative_prompt": "noise", - "guidance_scale": 7.0, - "num_inference_steps": 50, - "seed": 123, - } - response = client.post("/v1/audio/generate", json=payload) - assert response.status_code == 200 - assert response.headers["content-type"] == "audio/wav" - - def test_missing_input_rejected(self, client): - payload = {} - response = client.post("/v1/audio/generate", json=payload) - assert response.status_code == 422 - - def test_extra_unknown_fields_ignored(self, client): - payload = {"input": "test", "unknown_field": "value"} - response = client.post("/v1/audio/generate", json=payload) - # Pydantic v2 ignores extra fields by default - assert response.status_code == 200 diff --git a/tests/entrypoints/test_omni_diffusion.py b/tests/entrypoints/test_omni_diffusion.py index 2fbe7a8b42b..90b0fd05326 100644 --- a/tests/entrypoints/test_omni_diffusion.py +++ b/tests/entrypoints/test_omni_diffusion.py @@ -612,7 +612,7 @@ def test_initialize_stage_configs_called_when_none( """Test that stage configs are auto-loaded when stage_configs_path is None.""" def _fake_loader( - config_path: str, + model: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -687,7 +687,7 @@ def test_generate_raises_on_length_mismatch(monkeypatch: pytest.MonkeyPatch, moc """Test that generate raises ValueError when sampling_params_list length doesn't match.""" def _fake_loader( - config_path: str, + model: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -742,7 +742,7 @@ def test_generate_pipeline_and_final_outputs(monkeypatch: pytest.MonkeyPatch, mo stage_cfg1["processed_input"] = ["processed-for-stage-1"] def _fake_loader( - config_path: str, + model: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -848,7 +848,7 @@ def test_generate_pipeline_with_batch_input(monkeypatch: pytest.MonkeyPatch, moc stage_cfg1["stage_id"] = 1 def _fake_loader( - config_path: str, + model: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -969,7 +969,7 @@ def test_generate_no_final_output_returns_empty( stage_cfg1["final_output"] = False def _fake_loader( - config_path: str, + model: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -1062,7 +1062,7 @@ def test_generate_sampling_params_none_use_default( stage_cfg1["final_output"] = False def _fake_loader( - config_path: str, + model: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -1140,7 +1140,7 @@ def test_wait_for_stages_ready_timeout(monkeypatch: pytest.MonkeyPatch, mocker: """Test that _wait_for_stages_ready handles timeout correctly.""" def _fake_loader( - config_path: str, + model: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -1202,7 +1202,7 @@ def test_generate_handles_error_messages(monkeypatch: pytest.MonkeyPatch, mocker """Test that generate handles error messages from stages correctly.""" def _fake_loader( - config_path: str, + model: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -1283,7 +1283,7 @@ def test_close_sends_shutdown_signal(monkeypatch: pytest.MonkeyPatch, mocker: Mo """Test that close() sends shutdown signal to all input queues.""" def _fake_loader( - config_path: str, + model: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, diff --git a/tests/entrypoints/test_omni_llm.py b/tests/entrypoints/test_omni_llm.py index 0e1745c6833..b11b14c8939 100644 --- a/tests/entrypoints/test_omni_llm.py +++ b/tests/entrypoints/test_omni_llm.py @@ -577,7 +577,7 @@ def test_initialize_stage_configs_called_when_none( """Test that stage configs are auto-loaded when stage_configs_path is None.""" def _fake_loader( - config_path: str, + model: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -654,7 +654,7 @@ def test_generate_raises_on_length_mismatch(monkeypatch: pytest.MonkeyPatch, moc """Test that generate raises ValueError when sampling_params_list length doesn't match.""" def _fake_loader( - config_path: str, + model: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -711,7 +711,7 @@ def test_generate_pipeline_and_final_outputs(monkeypatch: pytest.MonkeyPatch, mo stage_cfg1["processed_input"] = ["processed-for-stage-1"] def _fake_loader( - config_path: str, + model: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -819,7 +819,7 @@ def test_generate_no_final_output_returns_empty( stage_cfg1["final_output"] = False def _fake_loader( - config_path: str, + model: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -911,7 +911,7 @@ def test_generate_sampling_params_none_use_default( stage_cfg1["final_output"] = False def _fake_loader( - config_path: str, + model: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -988,7 +988,7 @@ def test_wait_for_stages_ready_timeout(monkeypatch: pytest.MonkeyPatch, mocker: """Test that _wait_for_stages_ready handles timeout correctly.""" def _fake_loader( - config_path: str, + model: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -1052,7 +1052,7 @@ def test_generate_handles_error_messages(monkeypatch: pytest.MonkeyPatch, mocker """Test that generate handles error messages from stages correctly.""" def _fake_loader( - config_path: str, + model: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, @@ -1135,7 +1135,7 @@ def test_close_sends_shutdown_signal(monkeypatch: pytest.MonkeyPatch, mocker: Mo """Test that close() sends shutdown signal to all input queues.""" def _fake_loader( - config_path: str, + model: str, stage_configs_path: str | None = None, base_engine_args: dict | None = None, default_stage_cfg_factory=None, diff --git a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py index a1eb74eaf18..22dfc06c5a8 100644 --- a/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py +++ b/vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py @@ -356,6 +356,7 @@ def forward( negative_prompt: str | list[str] | None = None, audio_end_in_s: float | None = None, audio_start_in_s: float = 0.0, + num_inference_steps: int = 100, guidance_scale: float = 7.0, num_waveforms_per_prompt: int = 1, generator: torch.Generator | list[torch.Generator] | None = None, @@ -373,6 +374,7 @@ def forward( negative_prompt: Negative prompt for CFG audio_end_in_s: Audio end time in seconds (max ~47s for stable-audio-open-1.0) audio_start_in_s: Audio start time in seconds + num_inference_steps: Number of denoising steps guidance_scale: CFG scale num_waveforms_per_prompt: Number of audio outputs per prompt generator: Random generator for reproducibility @@ -393,7 +395,7 @@ def forward( elif req.prompts: negative_prompt = ["" if isinstance(p, str) else (p.get("negative_prompt") or "") for p in req.prompts] - num_inference_steps = req.sampling_params.num_inference_steps + num_inference_steps = req.sampling_params.num_inference_steps or num_inference_steps if req.sampling_params.guidance_scale_provided: guidance_scale = req.sampling_params.guidance_scale diff --git a/vllm_omni/entrypoints/omni_llm.py b/vllm_omni/entrypoints/omni_llm.py index 91c1c6bed1d..27f470b4b5e 100644 --- a/vllm_omni/entrypoints/omni_llm.py +++ b/vllm_omni/entrypoints/omni_llm.py @@ -28,7 +28,6 @@ load_stage_configs_from_model, load_stage_configs_from_yaml, resolve_model_config_path, - resolve_model_type, ) logger = init_logger(__name__) @@ -85,13 +84,12 @@ def __init__( self.worker_backend = kwargs.get("worker_backend", "multi_process") self.ray_address = kwargs.get("ray_address", None) self.batch_timeout = batch_timeout - self.model_type = resolve_model_type(model) self.log_stats: bool = bool(log_stats) # Load stage configurations if stage_configs_path is None: - self.config_path = resolve_model_config_path(self.model_type) - self.stage_configs = load_stage_configs_from_model(config_path=self.config_path) + self.config_path = resolve_model_config_path(model) + self.stage_configs = load_stage_configs_from_model(model) else: self.config_path = stage_configs_path self.stage_configs = load_stage_configs_from_yaml(stage_configs_path) diff --git a/vllm_omni/entrypoints/openai/api_server.py b/vllm_omni/entrypoints/openai/api_server.py index bb4649bf148..6db97dd2ddb 100644 --- a/vllm_omni/entrypoints/openai/api_server.py +++ b/vllm_omni/entrypoints/openai/api_server.py @@ -87,10 +87,7 @@ encode_image_base64, parse_size, ) -from vllm_omni.entrypoints.openai.protocol.audio import ( - OpenAICreateAudioGenerateRequest, - OpenAICreateSpeechRequest, -) +from vllm_omni.entrypoints.openai.protocol.audio import OpenAICreateSpeechRequest from vllm_omni.entrypoints.openai.protocol.images import ( ImageData, ImageGenerationRequest, @@ -100,7 +97,6 @@ VideoGenerationRequest, VideoGenerationResponse, ) -from vllm_omni.entrypoints.openai.serving_audio_generate import OmniOpenAIServingAudioGenerate from vllm_omni.entrypoints.openai.serving_chat import OmniOpenAIServingChat from vllm_omni.entrypoints.openai.serving_speech import OmniOpenAIServingSpeech from vllm_omni.entrypoints.openai.serving_video import OmniOpenAIServingVideo @@ -179,29 +175,8 @@ class _DiffusionServingModels: provide a lightweight fallback. """ - class _NullModelConfig: - def __getattr__(self, name): - return None - - class _Unsupported: - def __init__(self, name: str): - self.name = name - - def __call__(self, *args, **kwargs): - raise NotImplementedError(f"{self.name} is not supported in diffusion mode") - - def __getattr__(self, attr): - raise NotImplementedError(f"{self.name}.{attr} is not supported in diffusion mode") - def __init__(self, base_model_paths: list[BaseModelPath]) -> None: self._base_model_paths = base_model_paths - self.model_config = self._NullModelConfig() - - def __getattr__(self, name): - """Return a sentinel that raises NotImplementedError if called or - accessed, so any use of unsupported OpenAIServingModels features in - diffusion mode fails loudly with a descriptive message.""" - return self._Unsupported(name) async def show_available_models(self) -> ModelList: return ModelList( @@ -459,10 +434,10 @@ async def omni_init_app_state( # For omni models state.stage_configs = engine_client.stage_configs if hasattr(engine_client, "stage_configs") else None - model_name = served_model_names[0] if served_model_names else args.model # Pure Diffusion mode: use simplified initialization logic if is_pure_diffusion: + model_name = served_model_names[0] if served_model_names else args.model state.vllm_config = None state.diffusion_engine = engine_client state.openai_serving_models = _DiffusionServingModels(base_model_paths) @@ -475,16 +450,6 @@ async def omni_init_app_state( model_name=model_name, ) - # audio related - state.openai_serving_speech = None - state.openai_serving_audio_generate = OmniOpenAIServingAudioGenerate.for_diffusion( - engine_client, - state.openai_serving_models, - request_logger=request_logger, - model_name=model_name, - ) - - # video related diffusion_stage_configs = engine_client.stage_configs if hasattr(engine_client, "stage_configs") else None state.openai_serving_video = OmniOpenAIServingVideo.for_diffusion( diffusion_engine=engine_client, # type: ignore @@ -492,6 +457,7 @@ async def omni_init_app_state( stage_configs=diffusion_stage_configs, ) + state.enable_server_load_tracking = getattr(args, "enable_server_load_tracking", False) state.server_load_metrics = 0 logger.info("Pure diffusion API server initialized for model: %s", model_name) return @@ -766,11 +732,7 @@ async def omni_init_app_state( ) state.openai_serving_speech = OmniOpenAIServingSpeech( - engine_client, state.openai_serving_models, request_logger=request_logger, model_name=model_name - ) - - state.openai_serving_audio_generate = OmniOpenAIServingAudioGenerate( - engine_client, state.openai_serving_models, request_logger=request_logger, model_name=model_name + engine_client, state.openai_serving_models, request_logger=request_logger ) state.openai_serving_video = OmniOpenAIServingVideo( @@ -795,10 +757,6 @@ def Omnispeech(request: Request) -> OmniOpenAIServingSpeech | None: return request.app.state.openai_serving_speech -def OmniAudioGenerate(request: Request) -> OmniOpenAIServingAudioGenerate | None: - return getattr(request.app.state, "openai_serving_audio_generate", None) - - @router.post( "/v1/chat/completions", dependencies=[Depends(validate_json_request)], @@ -909,34 +867,6 @@ async def create_speech(request: OpenAICreateSpeechRequest, raw_request: Request raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)) from e -@router.post( - "/v1/audio/generate", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.OK.value: {"content": {"audio/*": {}}}, - HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, - HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, - HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, - }, -) -@with_cancellation -@load_aware_call -async def create_audio_generate(request: OpenAICreateAudioGenerateRequest, raw_request: Request): - handler = OmniAudioGenerate(raw_request) - if handler is None: - base_server = getattr(raw_request.app.state, "openai_serving_tokenization", None) - if base_server is None: - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND.value, - detail="The model does not support Audio Generate API", - ) - return base_server.create_error_response(message="The model does not support Audio Generate API") - try: - return await handler.create_audio_generate(request, raw_request) - except Exception as e: - raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)) from e - - @router.get( "/v1/audio/voices", responses={ diff --git a/vllm_omni/entrypoints/openai/audio_utils_mixin.py b/vllm_omni/entrypoints/openai/audio_utils_mixin.py index 3023b67ad65..13df32ebe00 100644 --- a/vllm_omni/entrypoints/openai/audio_utils_mixin.py +++ b/vllm_omni/entrypoints/openai/audio_utils_mixin.py @@ -45,10 +45,6 @@ def create_audio(self, audio_obj: CreateAudio) -> AudioResponse: "Only mono (1D) and stereo (2D) are supported." ) - if audio_tensor.ndim == 2 and audio_tensor.shape[0] == 2: - # Convert from [channels, samples] to [samples, channels] - audio_tensor = audio_tensor.T - audio_tensor, sample_rate = self._apply_speed_adjustment(audio_tensor, speed, sample_rate) supported_formats = { diff --git a/vllm_omni/entrypoints/openai/protocol/audio.py b/vllm_omni/entrypoints/openai/protocol/audio.py index 91690cbf0c2..1efab8ebec2 100644 --- a/vllm_omni/entrypoints/openai/protocol/audio.py +++ b/vllm_omni/entrypoints/openai/protocol/audio.py @@ -85,53 +85,6 @@ def validate_streaming_constraints(self) -> "OpenAICreateSpeechRequest": return self -class OpenAICreateAudioGenerateRequest(BaseModel): - """Request model for audio generation via diffusion models (e.g. Stable Audio).""" - - input: str = Field( - description="Text prompt describing the audio to generate", - ) - model: str | None = None - response_format: Literal["wav", "pcm", "flac", "mp3", "aac", "opus"] = "wav" - speed: float | None = Field( - default=1.0, - ge=0.25, - le=4.0, - ) - stream_format: Literal["sse", "audio"] | None = "audio" - audio_length: float | None = Field( - default=None, - description="Audio length in seconds", - ) - audio_start: float | None = Field( - default=0.0, - description="Audio start time in seconds", - ) - negative_prompt: str | None = Field( - default=None, - description="Negative prompt for classifier-free guidance", - ) - guidance_scale: float | None = Field( - default=None, - description="Guidance scale for diffusion models", - ) - num_inference_steps: int | None = Field( - default=None, - description="Number of inference steps", - ) - seed: int | None = Field( - default=None, - description="Random seed for reproducibility", - ) - - @field_validator("stream_format") - @classmethod - def validate_stream_format(cls, v: str) -> str: - if v == "sse": - raise ValueError("'sse' is not a supported stream_format yet. Please use 'audio'.") - return v - - class CreateAudio(BaseModel): audio_tensor: np.ndarray sample_rate: int = 24000 diff --git a/vllm_omni/entrypoints/openai/serving_audio_generate.py b/vllm_omni/entrypoints/openai/serving_audio_generate.py deleted file mode 100644 index bd5c37157a0..00000000000 --- a/vllm_omni/entrypoints/openai/serving_audio_generate.py +++ /dev/null @@ -1,168 +0,0 @@ -import asyncio - -import torch -from fastapi import Request -from fastapi.responses import Response -from vllm.entrypoints.openai.engine.serving import OpenAIServing -from vllm.logger import init_logger -from vllm.utils import random_uuid - -from vllm_omni.entrypoints.openai.audio_utils_mixin import AudioMixin -from vllm_omni.entrypoints.openai.protocol.audio import ( - AudioResponse, - CreateAudio, - OpenAICreateAudioGenerateRequest, -) -from vllm_omni.inputs.data import OmniDiffusionSamplingParams -from vllm_omni.outputs import OmniRequestOutput - -logger = init_logger(__name__) - - -class OmniOpenAIServingAudioGenerate(OpenAIServing, AudioMixin): - """Serving class for audio generation via diffusion models (e.g. Stable Audio).""" - - def __init__(self, *args, **kwargs): - self.model_name = kwargs.pop("model_name", None) - super().__init__(*args, **kwargs) - self.diffusion_mode = False - - @classmethod - def for_diffusion(cls, *args, **kwargs) -> "OmniOpenAIServingAudioGenerate": - """Create an instance configured to run in diffusion mode.""" - instance = cls(*args, **kwargs) - instance.diffusion_mode = True - return instance - - def _is_stable_audio_model(self) -> bool: - return self.engine_client.model_type == "StableAudioPipeline" - - async def create_audio_generate( - self, - request: OpenAICreateAudioGenerateRequest, - raw_request: Request | None = None, - ): - """ - Generate audio using diffusion-based models (e.g. Stable Audio). - - This endpoint is designed for audio generation models as - opposed to TTS models that specifically generate speech. - """ - - error_check_ret = await self._check_model(request) - if error_check_ret is not None: - logger.error("Error with model %s", error_check_ret) - return error_check_ret - - if self.engine_client.errored: - raise self.engine_client.dead_error - - request_id = f"audiogen-{random_uuid()}" - - try: - default_sr = 44100 # Default sample rate for Stable Audio - - # Build prompt for diffusion audio generation - prompt = { - "prompt": request.input, - } - if request.negative_prompt: - prompt["negative_prompt"] = request.negative_prompt - - # Build sampling params for diffusion - sampling_params_list = [OmniDiffusionSamplingParams(num_outputs_per_prompt=1)] - - # Create generator if seed provided - if request.seed is not None: - from vllm_omni.platforms import current_omni_platform - - generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(request.seed) - sampling_params_list[0].generator = generator - - if request.guidance_scale is not None: - sampling_params_list[0].guidance_scale = request.guidance_scale - - if request.num_inference_steps is not None: - sampling_params_list[0].num_inference_steps = request.num_inference_steps - - # Set up audio duration parameters - if request.audio_length is not None: - audio_length = request.audio_length - audio_start = request.audio_start if request.audio_start is not None else 0.0 - audio_end_in_s = audio_start + audio_length - sampling_params_list[0].extra_args = { - "audio_start_in_s": audio_start, - "audio_end_in_s": audio_end_in_s, - } - - logger.info( - "Audio generation request %s: prompt=%r", - request_id, - request.input[:50] + "..." if len(request.input) > 50 else request.input, - ) - - generator = self.engine_client.generate( - prompt=prompt, - request_id=request_id, - sampling_params_list=sampling_params_list, - output_modalities=["audio"], - ) - - final_output: OmniRequestOutput | None = None - async for res in generator: - final_output = res - - if final_output is None: - return self.create_error_response("No output generated from the model.") - - # Extract audio from output - audio_output = None - if hasattr(final_output, "multimodal_output") and final_output.multimodal_output: - audio_output = final_output.multimodal_output - if not audio_output and hasattr(final_output, "request_output"): - if final_output.request_output and hasattr(final_output.request_output, "multimodal_output"): - audio_output = final_output.request_output.multimodal_output - - # Check for audio data using either "audio" or "model_outputs" key - audio_key = None - if audio_output: - if "audio" in audio_output: - audio_key = "audio" - elif "model_outputs" in audio_output: - audio_key = "model_outputs" - - if not audio_output or audio_key is None: - return self.create_error_response("Audio generation model did not produce audio output.") - - audio_tensor = audio_output[audio_key] - sample_rate = audio_output.get("sr", default_sr) - if hasattr(sample_rate, "item"): - sample_rate = sample_rate.item() - - # Convert tensor to numpy - if hasattr(audio_tensor, "float"): - audio_tensor = audio_tensor.float().detach().cpu().numpy() - - # Squeeze batch dimension if present, but preserve channel dimension for stereo - if audio_tensor.ndim > 1: - audio_tensor = audio_tensor.squeeze() - - audio_obj = CreateAudio( - audio_tensor=audio_tensor, - sample_rate=int(sample_rate), - response_format=request.response_format or "wav", - speed=request.speed or 1.0, - stream_format=request.stream_format, - base64_encode=False, - ) - - audio_response: AudioResponse = self.create_audio(audio_obj) - return Response(content=audio_response.audio_data, media_type=audio_response.media_type) - - except asyncio.CancelledError: - return self.create_error_response("Client disconnected") - except ValueError as e: - return self.create_error_response(e) - except Exception as e: - logger.exception("Audio generation failed: %s", e) - return self.create_error_response(f"Audio generation failed: {e}") diff --git a/vllm_omni/entrypoints/openai/serving_speech.py b/vllm_omni/entrypoints/openai/serving_speech.py index 75cf956fa34..7bcf75ace9d 100644 --- a/vllm_omni/entrypoints/openai/serving_speech.py +++ b/vllm_omni/entrypoints/openai/serving_speech.py @@ -84,7 +84,6 @@ def _validate_path_within_directory(file_path: Path, directory: Path) -> bool: class OmniOpenAIServingSpeech(OpenAIServing, AudioMixin): def __init__(self, *args, **kwargs): - self.model_name = kwargs.pop("model_name", None) super().__init__(*args, **kwargs) # Initialize uploaded speakers storage speech_voice_samples_dir = os.environ.get("SPEECH_VOICE_SAMPLES", "/tmp/voice_samples") @@ -757,8 +756,6 @@ async def create_speech( request_id = f"speech-{random_uuid()}" try: - sampling_params_list = self.engine_client.default_sampling_params_list - default_sr = 24000 # Default sample rate for TTS models if self._is_tts: # Validate TTS parameters validation_error = self._validate_tts_request(request) @@ -784,6 +781,8 @@ async def create_speech( tts_params.get("task_type", ["unknown"])[0], ) + sampling_params_list = self.engine_client.default_sampling_params_list + generator = self.engine_client.generate( prompt=prompt, request_id=request_id, @@ -810,7 +809,7 @@ async def create_speech( return self.create_error_response("TTS model did not produce audio output.") audio_tensor = audio_output[audio_key] - sr_raw = audio_output.get("sr", default_sr) + sr_raw = audio_output.get("sr", 24000) sr_val = sr_raw[-1] if isinstance(sr_raw, list) and sr_raw else sr_raw sample_rate = sr_val.item() if hasattr(sr_val, "item") else int(sr_val) diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py index beec074d89a..69ce73fc47a 100644 --- a/vllm_omni/entrypoints/utils.py +++ b/vllm_omni/entrypoints/utils.py @@ -169,46 +169,22 @@ def _convert_dataclasses_to_dict(obj: Any) -> Any: return obj -def resolve_model_config_path(model_type: str) -> str | None: - """Resolve the stage config file path from the model type. +def resolve_model_config_path(model: str) -> str: + """Resolve the stage config file path from the model name. Resolves stage configuration path based on the model type and device type. First tries to find a device-specific YAML file from stage_configs/{device_type}/ directory. If not found, falls back to the default config file. Args: - model_type: Model type string + model: Model name or path (used to determine model_type) Returns: - String path to the stage configuration file if found, None otherwise - """ - default_config_path = current_omni_platform.get_default_stage_config_path() - config_file_name = f"{model_type}.yaml" - complete_config_path = PROJECT_ROOT / default_config_path / config_file_name - if os.path.exists(complete_config_path): - return str(complete_config_path) - - # Fall back to default config - stage_config_file = f"vllm_omni/model_executor/stage_configs/{model_type}.yaml" - stage_config_path = PROJECT_ROOT / stage_config_file - if not os.path.exists(stage_config_path): - return None - return str(stage_config_path) - - -def resolve_model_type(model: str) -> str: - """Resolve the model type from the model name. - - Args: - model: Model name or path - - Returns: - Model type string (e.g. ``"Qwen3TTSForConditionalGeneration"``, - ``"StableAudioPipeline"``). + String path to the stage configuration file Raises: - ValueError: If the model type cannot be determined from any - available configuration file. + ValueError: If model_type cannot be determined + FileNotFoundError: If no stage config file exists for the model type """ # Try to get config from standard transformers format first try: @@ -241,26 +217,42 @@ def resolve_model_type(model: str) -> str: f"Please ensure the model has proper configuration files with 'model_type' field" ) - return model_type + default_config_path = current_omni_platform.get_default_stage_config_path() + model_type_str = f"{model_type}.yaml" + complete_config_path = PROJECT_ROOT / default_config_path / model_type_str + if os.path.exists(complete_config_path): + return str(complete_config_path) + + # Fall back to default config + stage_config_file = f"vllm_omni/model_executor/stage_configs/{model_type}.yaml" + stage_config_path = PROJECT_ROOT / stage_config_file + if not os.path.exists(stage_config_path): + return None + return str(stage_config_path) + +def load_stage_configs_from_model(model: str, base_engine_args: dict | None = None) -> list: + """Load stage configurations from model's default config file. -def load_stage_configs_from_model(config_path: str | None, base_engine_args: dict | None = None) -> list: - """Load stage configurations from a resolved config file path. + Loads stage configurations based on the model type and device type. + First tries to load a device-specific YAML file from stage_configs/{device_type}/ + directory. If not found, falls back to the default config file. Args: - config_path: Path to the YAML configuration file, or None. - When None, returns an empty list. - base_engine_args: Optional engine arguments to merge with stage configs. + model: Model name or path (used to determine model_type) Returns: - List of stage configuration dictionaries, or empty list if - config_path is None. + List of stage configuration dictionaries + + Raises: + FileNotFoundError: If no stage config file exists for the model type """ if base_engine_args is None: base_engine_args = {} - if config_path is None: + stage_config_path = resolve_model_config_path(model) + if stage_config_path is None: return [] - stage_configs = load_stage_configs_from_yaml(config_path=config_path, base_engine_args=base_engine_args) + stage_configs = load_stage_configs_from_yaml(config_path=stage_config_path, base_engine_args=base_engine_args) return stage_configs @@ -314,10 +306,9 @@ def load_and_resolve_stage_configs( Returns: Tuple of (config_path, stage_configs) """ - model_type = resolve_model_type(model) if stage_configs_path is None: - config_path = resolve_model_config_path(model_type) - stage_configs = load_stage_configs_from_model(config_path, base_engine_args=kwargs) + config_path = resolve_model_config_path(model) + stage_configs = load_stage_configs_from_model(model, base_engine_args=kwargs) if not stage_configs: if default_stage_cfg_factory is not None: default_stage_cfg = default_stage_cfg_factory()