Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/guides/audio.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ Transcribe audio to text (OpenAI Whisper API compatible).
- `language`: Language code (optional, auto-detected)
- `response_format`: `json` or `text`

**Limits:**
- Default upload cap: 25 MiB
- Override with `--max-audio-upload-mb`

**Example:**
```bash
curl http://localhost:8000/v1/audio/transcriptions \
Expand All @@ -218,6 +222,10 @@ Generate speech from text (OpenAI TTS API compatible).
- `speed`: Speech speed (0.5 to 2.0)
- `response_format`: `wav`, `mp3`

**Limits:**
- Default input cap: 4096 characters
- Override with `--max-tts-input-chars`

**Example:**
```bash
curl http://localhost:8000/v1/audio/speech \
Expand Down
2 changes: 2 additions & 0 deletions docs/reference/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ vllm-mlx serve <model> [options]
| `--max-num-seqs` | Max concurrent sequences | 256 |
| `--default-temperature` | Default temperature when not specified in request | None |
| `--default-top-p` | Default top_p when not specified in request | None |
| `--max-audio-upload-mb` | Maximum uploaded audio size for `/v1/audio/transcriptions` | 25 |
| `--max-tts-input-chars` | Maximum text length accepted by `/v1/audio/speech` | 4096 |
| `--reasoning-parser` | Parser for reasoning models (`qwen3`, `deepseek_r1`) | None |
| `--embedding-model` | Pre-load an embedding model at startup | None |
| `--enable-auto-tool-choice` | Enable automatic tool calling | False |
Expand Down
2 changes: 2 additions & 0 deletions docs/reference/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
| `--rate-limit` | Requests per minute per client (0 = disabled) | `0` |
| `--timeout` | Request timeout in seconds | `300` |
| `--enable-metrics` | Expose Prometheus metrics on `/metrics` | `false` |
| `--max-audio-upload-mb` | Maximum uploaded audio size for `/v1/audio/transcriptions` | `25` |
| `--max-tts-input-chars` | Maximum text length accepted by `/v1/audio/speech` | `4096` |

### Batching Options

Expand Down
91 changes: 91 additions & 0 deletions tests/test_audio_limits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for audio endpoint resource limits."""

from pathlib import Path

import pytest
from fastapi import HTTPException

from vllm_mlx.audio_limits import (
DEFAULT_MAX_AUDIO_UPLOAD_MB,
DEFAULT_MAX_TTS_INPUT_CHARS,
save_upload_with_limit,
validate_tts_input_length,
)


class FakeUpload:
    """Minimal async stand-in for an uploaded file.

    Serves one pre-seeded chunk per ``read()`` call and returns ``b""``
    once exhausted, mimicking end-of-stream.
    """

    def __init__(self, chunks: list[bytes], filename: str = "audio.wav"):
        self._pending = list(chunks)
        self.filename = filename

    async def read(self, _size: int = -1) -> bytes:
        # Pop chunks in FIFO order; empty bytes signals EOF to the consumer.
        return self._pending.pop(0) if self._pending else b""


class TestAudioUploadLimits:
    """save_upload_with_limit enforces the configured byte cap."""

    @pytest.mark.asyncio
    async def test_save_upload_with_limit_writes_file(self):
        # Two chunks totalling 12 bytes sit under the 32-byte cap and
        # must land on disk verbatim.
        upload = FakeUpload([b"a" * 8, b"b" * 4])

        saved = await save_upload_with_limit(upload, max_bytes=32)

        try:
            assert Path(saved).read_bytes() == b"a" * 8 + b"b" * 4
        finally:
            Path(saved).unlink(missing_ok=True)

    @pytest.mark.asyncio
    async def test_save_upload_with_limit_rejects_oversize_and_cleans_up(self):
        # The third chunk pushes the running total past 32 bytes, which
        # must surface as an HTTP 413.
        upload = FakeUpload([b"a" * 16, b"b" * 16, b"c"])

        with pytest.raises(HTTPException) as err:
            await save_upload_with_limit(upload, max_bytes=32)

        assert err.value.status_code == 413
        assert "Audio upload too large" in err.value.detail


class TestTTSInputLimits:
    """validate_tts_input_length enforces the configured character cap."""

    def test_validate_tts_input_length_accepts_short_text(self):
        # Under the cap: must return without raising.
        validate_tts_input_length("hello", max_chars=16)

    def test_validate_tts_input_length_rejects_oversized_text(self):
        # One character over the cap must surface as an HTTP 413.
        with pytest.raises(HTTPException) as err:
            validate_tts_input_length("x" * 17, max_chars=16)

        assert err.value.status_code == 413
        assert "TTS input too long" in err.value.detail


class TestAudioLimitParsers:
    """Both CLI entry points expose and default the audio limit flags."""

    def test_top_level_cli_exposes_audio_limit_flags(self):
        from vllm_mlx.cli import create_parser

        argv = [
            "serve",
            "mlx-community/Llama-3.2-3B-Instruct-4bit",
            "--max-audio-upload-mb",
            "12",
            "--max-tts-input-chars",
            "2048",
        ]
        args = create_parser().parse_args(argv)

        assert args.max_audio_upload_mb == 12
        assert args.max_tts_input_chars == 2048

    def test_standalone_server_parser_defaults(self):
        from vllm_mlx.server import create_parser

        args = create_parser().parse_args(
            ["--model", "mlx-community/Llama-3.2-3B-Instruct-4bit"]
        )

        # Flags omitted -> module-level defaults apply.
        assert args.max_audio_upload_mb == DEFAULT_MAX_AUDIO_UPLOAD_MB
        assert args.max_tts_input_chars == DEFAULT_MAX_TTS_INPUT_CHARS
72 changes: 72 additions & 0 deletions vllm_mlx/audio_limits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# SPDX-License-Identifier: Apache-2.0
"""Resource limits for optional audio endpoints."""

import os
import tempfile
from pathlib import Path
from typing import Protocol

from fastapi import HTTPException

DEFAULT_MAX_AUDIO_UPLOAD_MB = 25
DEFAULT_MAX_AUDIO_UPLOAD_BYTES = DEFAULT_MAX_AUDIO_UPLOAD_MB * 1024 * 1024
DEFAULT_MAX_TTS_INPUT_CHARS = 4096
UPLOAD_CHUNK_SIZE = 1024 * 1024


class AsyncReadableUpload(Protocol):
    """Structural type for async uploads (the subset of FastAPI's
    UploadFile interface that save_upload_with_limit needs)."""

    # Original client-supplied filename, if any; used to pick a suffix.
    filename: str | None

    async def read(self, size: int = -1) -> bytes: ...


async def save_upload_with_limit(
    file: AsyncReadableUpload,
    *,
    max_bytes: int,
    default_suffix: str = ".wav",
    chunk_size: int = UPLOAD_CHUNK_SIZE,
) -> str:
    """
    Stream an uploaded file to disk while enforcing a hard byte limit.

    The upload is consumed in fixed-size chunks so large audio uploads are
    never buffered entirely in memory. If the cumulative size exceeds
    ``max_bytes``, the partial temp file is removed and an HTTP 413 is raised.

    Args:
        file: Async readable upload (e.g. a FastAPI ``UploadFile``).
        max_bytes: Hard cap on the total number of bytes accepted.
        default_suffix: Suffix used when the upload has no filename suffix.
        chunk_size: Read granularity in bytes.

    Returns:
        Path of the temporary file holding the upload. The caller owns the
        file and is responsible for deleting it.

    Raises:
        HTTPException: 413 when the upload exceeds ``max_bytes``.
    """
    suffix = Path(file.filename or "").suffix or default_suffix
    total_bytes = 0

    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    tmp_path = tmp.name
    try:
        # Keep the handle inside a ``with`` so it is closed before any
        # cleanup below — unlinking an open file fails on some platforms
        # (e.g. Windows) and would mask the original exception.
        with tmp:
            while True:
                chunk = await file.read(chunk_size)
                if not chunk:
                    break
                total_bytes += len(chunk)
                # Check before writing so we never persist more than the cap.
                if total_bytes > max_bytes:
                    raise HTTPException(
                        status_code=413,
                        detail=(
                            f"Audio upload too large: {total_bytes} bytes exceeds "
                            f"the configured limit of {max_bytes} bytes."
                        ),
                    )
                tmp.write(chunk)
    except Exception:
        # Remove the partial file on any failure (including the 413 above);
        # missing_ok avoids a racy exists() check.
        Path(tmp_path).unlink(missing_ok=True)
        raise

    return tmp_path


def validate_tts_input_length(text: str, *, max_chars: int) -> None:
    """Reject oversized TTS requests before synthesis starts.

    Raises:
        HTTPException: 413 when ``text`` is longer than ``max_chars``.
    """
    n_chars = len(text)
    if n_chars <= max_chars:
        return
    raise HTTPException(
        status_code=413,
        detail=(
            f"TTS input too long: {n_chars} characters exceeds the configured "
            f"limit of {max_chars} characters."
        ),
    )
18 changes: 18 additions & 0 deletions vllm_mlx/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def serve_command(args):
server._default_temperature = args.default_temperature
if args.default_top_p is not None:
server._default_top_p = args.default_top_p
server._max_audio_upload_bytes = args.max_audio_upload_mb * 1024 * 1024
server._max_tts_input_chars = args.max_tts_input_chars

# Configure reasoning parser
if args.reasoning_parser:
Expand Down Expand Up @@ -120,6 +122,10 @@ def serve_command(args):
print(f" Reasoning: ENABLED (parser: {args.reasoning_parser})")
else:
print(" Reasoning: Use --reasoning-parser to enable")
print(
f" Audio upload limit: {args.max_audio_upload_mb} MiB, "
f"TTS input limit: {args.max_tts_input_chars} chars"
)
print("=" * 60)

# Pre-download model with retry/timeout
Expand Down Expand Up @@ -890,6 +896,18 @@ def create_parser() -> argparse.ArgumentParser:
action="store_true",
help="Expose Prometheus metrics on /metrics (disabled by default)",
)
serve_parser.add_argument(
"--max-audio-upload-mb",
type=int,
default=25,
help="Maximum size of uploaded audio files in MiB (default: 25)",
)
serve_parser.add_argument(
"--max-tts-input-chars",
type=int,
default=4096,
help="Maximum number of characters accepted by /v1/audio/speech (default: 4096)",
)
# Tool calling options
serve_parser.add_argument(
"--enable-auto-tool-choice",
Expand Down
41 changes: 35 additions & 6 deletions vllm_mlx/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
import os
import re
import secrets
import tempfile
import threading
import time
import uuid
Expand Down Expand Up @@ -138,6 +137,13 @@
extract_multimodal_content,
is_mllm_model, # noqa: F401
)
from .audio_limits import (
DEFAULT_MAX_AUDIO_UPLOAD_BYTES,
DEFAULT_MAX_AUDIO_UPLOAD_MB,
DEFAULT_MAX_TTS_INPUT_CHARS,
save_upload_with_limit,
validate_tts_input_length,
)
from .engine import BaseEngine, BatchedEngine, GenerationOutput, SimpleEngine
from .endpoint_model_policies import (
resolve_embedding_model_name,
Expand All @@ -161,6 +167,8 @@
_default_temperature: float | None = None # Set via --default-temperature
_default_top_p: float | None = None # Set via --default-top-p
_metrics_enabled = False
_max_audio_upload_bytes: int = DEFAULT_MAX_AUDIO_UPLOAD_BYTES
_max_tts_input_chars: int = DEFAULT_MAX_TTS_INPUT_CHARS

_FALLBACK_TEMPERATURE = 0.7
_FALLBACK_TOP_P = 0.9
Expand Down Expand Up @@ -2045,11 +2053,12 @@ async def create_transcription(
_stt_engine = STTEngine(model_name)
_stt_engine.load()

# Save uploaded file temporarily
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
content = await file.read()
tmp.write(content)
tmp_path = tmp.name
# Stream uploaded file to disk under a hard size cap.
tmp_path = await save_upload_with_limit(
file,
max_bytes=_max_audio_upload_bytes,
default_suffix=".wav",
)

try:
result = _stt_engine.transcribe(tmp_path, language=language)
Expand Down Expand Up @@ -2106,6 +2115,7 @@ async def create_speech(
from .audio.tts import TTSEngine # Lazy import - optional feature

model_name = resolve_tts_model_name(model)
validate_tts_input_length(input, max_chars=_max_tts_input_chars)

# Load engine if needed
if _tts_engine is None or _tts_engine.model_name != model_name:
Expand Down Expand Up @@ -4241,6 +4251,7 @@ def main():
# Set global configuration
global _api_key, _default_timeout, _rate_limiter, _metrics_enabled
global _default_temperature, _default_top_p
global _max_audio_upload_bytes, _max_tts_input_chars
_api_key = args.api_key
_default_timeout = args.timeout
_metrics_enabled = args.enable_metrics
Expand All @@ -4249,6 +4260,8 @@ def main():
_default_temperature = args.default_temperature
if args.default_top_p is not None:
_default_top_p = args.default_top_p
_max_audio_upload_bytes = args.max_audio_upload_mb * 1024 * 1024
_max_tts_input_chars = args.max_tts_input_chars

# Configure rate limiter
if args.rate_limit > 0:
Expand Down Expand Up @@ -4278,6 +4291,10 @@ def main():
logger.warning(" Remote code loading: ENABLED (--trust-remote-code)")
else:
logger.info(" Remote code loading: DISABLED (default)")
logger.info(
f" Audio upload limit: {args.max_audio_upload_mb} MiB, "
f"TTS input limit: {args.max_tts_input_chars} chars"
)
logger.info("=" * 60)

# Set MCP config for lifespan
Expand Down Expand Up @@ -4426,6 +4443,18 @@ def create_parser() -> argparse.ArgumentParser:
default=None,
help="Default top_p for generation when not specified in request",
)
parser.add_argument(
"--max-audio-upload-mb",
type=int,
default=DEFAULT_MAX_AUDIO_UPLOAD_MB,
help="Maximum size of uploaded audio files in MiB (default: 25)",
)
parser.add_argument(
"--max-tts-input-chars",
type=int,
default=DEFAULT_MAX_TTS_INPUT_CHARS,
help="Maximum number of characters accepted by /v1/audio/speech (default: 4096)",
)
return parser


Expand Down
Loading