diff --git a/docs/guides/audio.md b/docs/guides/audio.md
index 2c26a8fe5..9f5b80769 100644
--- a/docs/guides/audio.md
+++ b/docs/guides/audio.md
@@ -200,6 +200,10 @@ Transcribe audio to text (OpenAI Whisper API compatible).
 - `language`: Language code (optional, auto-detected)
 - `response_format`: `json` or `text`
 
+**Limits:**
+- Default upload cap: 25 MiB
+- Override with `--max-audio-upload-mb`
+
 **Example:**
 ```bash
 curl http://localhost:8000/v1/audio/transcriptions \
@@ -218,6 +222,10 @@ Generate speech from text (OpenAI TTS API compatible).
 - `speed`: Speech speed (0.5 to 2.0)
 - `response_format`: `wav`, `mp3`
 
+**Limits:**
+- Default input cap: 4096 characters
+- Override with `--max-tts-input-chars`
+
 **Example:**
 ```bash
 curl http://localhost:8000/v1/audio/speech \
diff --git a/docs/reference/cli.md b/docs/reference/cli.md
index 85c8d9c35..949ff8889 100644
--- a/docs/reference/cli.md
+++ b/docs/reference/cli.md
@@ -42,6 +42,8 @@ vllm-mlx serve [options]
 | `--max-num-seqs` | Max concurrent sequences | 256 |
 | `--default-temperature` | Default temperature when not specified in request | None |
 | `--default-top-p` | Default top_p when not specified in request | None |
+| `--max-audio-upload-mb` | Maximum uploaded audio size for `/v1/audio/transcriptions` | 25 |
+| `--max-tts-input-chars` | Maximum text length accepted by `/v1/audio/speech` | 4096 |
 | `--reasoning-parser` | Parser for reasoning models (`qwen3`, `deepseek_r1`) | None |
 | `--embedding-model` | Pre-load an embedding model at startup | None |
 | `--enable-auto-tool-choice` | Enable automatic tool calling | False |
diff --git a/docs/reference/configuration.md b/docs/reference/configuration.md
index dcdff9d78..53df4c04a 100644
--- a/docs/reference/configuration.md
+++ b/docs/reference/configuration.md
@@ -20,6 +20,8 @@
 | `--rate-limit` | Requests per minute per client (0 = disabled) | `0` |
 | `--timeout` | Request timeout in seconds | `300` |
 | `--enable-metrics` | Expose Prometheus metrics on `/metrics` | `false` |
+| `--max-audio-upload-mb` | Maximum uploaded audio size for `/v1/audio/transcriptions` | `25` |
+| `--max-tts-input-chars` | Maximum text length accepted by `/v1/audio/speech` | `4096` |
 
 ### Batching Options
 
diff --git a/tests/test_audio_limits.py b/tests/test_audio_limits.py
new file mode 100644
index 000000000..596ebcfd6
--- /dev/null
+++ b/tests/test_audio_limits.py
@@ -0,0 +1,117 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for audio endpoint resource limits."""
+
+import asyncio
+import tempfile
+from pathlib import Path
+
+import pytest
+from fastapi import HTTPException
+
+from vllm_mlx.audio_limits import (
+    DEFAULT_MAX_AUDIO_UPLOAD_MB,
+    DEFAULT_MAX_TTS_INPUT_CHARS,
+    save_upload_with_limit,
+    validate_tts_input_length,
+)
+
+
+class FakeUpload:
+    def __init__(self, chunks: list[bytes], filename: str = "audio.wav"):
+        self._chunks = list(chunks)
+        self.filename = filename
+
+    async def read(self, _size: int = -1) -> bytes:
+        if not self._chunks:
+            return b""
+        return self._chunks.pop(0)
+
+
+class CancelledUpload(FakeUpload):
+    """Upload whose read raises CancelledError, as when the task is cancelled."""
+
+    async def read(self, _size: int = -1) -> bytes:
+        raise asyncio.CancelledError
+
+
+class TestAudioUploadLimits:
+    @pytest.mark.asyncio
+    async def test_save_upload_with_limit_writes_file(self):
+        upload = FakeUpload([b"a" * 8, b"b" * 4])
+
+        path = await save_upload_with_limit(upload, max_bytes=32)
+
+        try:
+            assert Path(path).read_bytes() == b"a" * 8 + b"b" * 4
+        finally:
+            Path(path).unlink(missing_ok=True)
+
+    @pytest.mark.asyncio
+    async def test_save_upload_with_limit_rejects_oversize_and_cleans_up(
+        self, tmp_path, monkeypatch
+    ):
+        # Route temp files into an isolated directory so a leaked partial
+        # upload would be visible to the final assertion.
+        monkeypatch.setattr(tempfile, "tempdir", str(tmp_path))
+        upload = FakeUpload([b"a" * 16, b"b" * 16, b"c"])
+
+        with pytest.raises(HTTPException) as exc_info:
+            await save_upload_with_limit(upload, max_bytes=32)
+
+        assert exc_info.value.status_code == 413
+        assert "Audio upload too large" in exc_info.value.detail
+        assert list(tmp_path.iterdir()) == []
+
+    @pytest.mark.asyncio
+    async def test_save_upload_with_limit_cleans_up_on_cancellation(
+        self, tmp_path, monkeypatch
+    ):
+        monkeypatch.setattr(tempfile, "tempdir", str(tmp_path))
+
+        with pytest.raises(asyncio.CancelledError):
+            await save_upload_with_limit(CancelledUpload([]), max_bytes=32)
+
+        assert list(tmp_path.iterdir()) == []
+
+
+class TestTTSInputLimits:
+    def test_validate_tts_input_length_accepts_short_text(self):
+        validate_tts_input_length("hello", max_chars=16)
+
+    def test_validate_tts_input_length_rejects_oversized_text(self):
+        with pytest.raises(HTTPException) as exc_info:
+            validate_tts_input_length("x" * 17, max_chars=16)
+
+        assert exc_info.value.status_code == 413
+        assert "TTS input too long" in exc_info.value.detail
+
+
+class TestAudioLimitParsers:
+    def test_top_level_cli_exposes_audio_limit_flags(self):
+        from vllm_mlx.cli import create_parser
+
+        parser = create_parser()
+        args = parser.parse_args(
+            [
+                "serve",
+                "mlx-community/Llama-3.2-3B-Instruct-4bit",
+                "--max-audio-upload-mb",
+                "12",
+                "--max-tts-input-chars",
+                "2048",
+            ]
+        )
+
+        assert args.max_audio_upload_mb == 12
+        assert args.max_tts_input_chars == 2048
+
+    def test_standalone_server_parser_defaults(self):
+        from vllm_mlx.server import create_parser
+
+        parser = create_parser()
+        args = parser.parse_args(
+            ["--model", "mlx-community/Llama-3.2-3B-Instruct-4bit"]
+        )
+
+        assert args.max_audio_upload_mb == DEFAULT_MAX_AUDIO_UPLOAD_MB
+        assert args.max_tts_input_chars == DEFAULT_MAX_TTS_INPUT_CHARS
diff --git a/vllm_mlx/audio_limits.py b/vllm_mlx/audio_limits.py
new file mode 100644
index 000000000..f2515f714
--- /dev/null
+++ b/vllm_mlx/audio_limits.py
@@ -0,0 +1,88 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Resource limits for optional audio endpoints."""
+
+import os
+import tempfile
+from pathlib import Path
+from typing import Protocol
+
+from fastapi import HTTPException
+
+DEFAULT_MAX_AUDIO_UPLOAD_MB = 25
+DEFAULT_MAX_AUDIO_UPLOAD_BYTES = DEFAULT_MAX_AUDIO_UPLOAD_MB * 1024 * 1024
+DEFAULT_MAX_TTS_INPUT_CHARS = 4096
+UPLOAD_CHUNK_SIZE = 1024 * 1024
+
+
+class AsyncReadableUpload(Protocol):
+    filename: str | None
+
+    async def read(self, size: int = -1) -> bytes: ...
+
+
+async def save_upload_with_limit(
+    file: AsyncReadableUpload,
+    *,
+    max_bytes: int,
+    default_suffix: str = ".wav",
+    chunk_size: int = UPLOAD_CHUNK_SIZE,
+) -> str:
+    """
+    Stream an uploaded file to disk while enforcing a hard byte limit.
+
+    This prevents large audio uploads from being buffered entirely in memory.
+
+    Args:
+        file: Upload exposing ``filename`` and an async ``read``.
+        max_bytes: Largest total upload size accepted, in bytes.
+        default_suffix: Temp-file suffix used when ``filename`` has none.
+        chunk_size: Number of bytes requested per ``read`` call.
+
+    Returns:
+        Path of the temp file (created with ``delete=False``; caller removes it).
+
+    Raises:
+        HTTPException: 413 once the bytes read exceed ``max_bytes``.
+    """
+    suffix = Path(file.filename or "").suffix or default_suffix
+    total_bytes = 0
+
+    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
+        tmp_path = tmp.name
+        try:
+            while True:
+                chunk = await file.read(chunk_size)
+                if not chunk:
+                    break
+                total_bytes += len(chunk)
+                if total_bytes > max_bytes:
+                    raise HTTPException(
+                        status_code=413,
+                        detail=(
+                            f"Audio upload too large: {total_bytes} bytes exceeds "
+                            f"the configured limit of {max_bytes} bytes."
+                        ),
+                    )
+                tmp.write(chunk)
+        except BaseException:
+            # BaseException rather than Exception: asyncio.CancelledError is
+            # not an Exception subclass, and a cancelled read must not leave
+            # the partial upload behind.
+            tmp.close()
+            if os.path.exists(tmp_path):
+                os.unlink(tmp_path)
+            raise
+
+    return tmp_path
+
+
+def validate_tts_input_length(text: str, *, max_chars: int) -> None:
+    """Reject oversized TTS requests before synthesis starts."""
+    if len(text) > max_chars:
+        raise HTTPException(
+            status_code=413,
+            detail=(
+                f"TTS input too long: {len(text)} characters exceeds the configured "
+                f"limit of {max_chars} characters."
+            ),
+        )
diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py
index e2c939895..162cbb96d 100644
--- a/vllm_mlx/cli.py
+++ b/vllm_mlx/cli.py
@@ -67,6 +67,8 @@ def serve_command(args):
         server._default_temperature = args.default_temperature
     if args.default_top_p is not None:
         server._default_top_p = args.default_top_p
+    server._max_audio_upload_bytes = args.max_audio_upload_mb * 1024 * 1024
+    server._max_tts_input_chars = args.max_tts_input_chars
 
     # Configure reasoning parser
     if args.reasoning_parser:
@@ -120,6 +122,10 @@ def serve_command(args):
         print(f" Reasoning: ENABLED (parser: {args.reasoning_parser})")
     else:
         print(" Reasoning: Use --reasoning-parser to enable")
+    print(
+        f" Audio upload limit: {args.max_audio_upload_mb} MiB, "
+        f"TTS input limit: {args.max_tts_input_chars} chars"
+    )
     print("=" * 60)
 
     # Pre-download model with retry/timeout
@@ -890,6 +896,18 @@ def create_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="Expose Prometheus metrics on /metrics (disabled by default)",
     )
+    serve_parser.add_argument(
+        "--max-audio-upload-mb",
+        type=int,
+        default=25,
+        help="Maximum size of uploaded audio files in MiB (default: 25)",
+    )
+    serve_parser.add_argument(
+        "--max-tts-input-chars",
+        type=int,
+        default=4096,
+        help="Maximum number of characters accepted by /v1/audio/speech (default: 4096)",
+    )
     # Tool calling options
     serve_parser.add_argument(
         "--enable-auto-tool-choice",
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index 95c8c610e..eb4b64a83 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -45,7 +45,6 @@
 import os
 import re
 import secrets
-import tempfile
 import threading
 import time
 import uuid
@@ -138,6 +137,13 @@
     extract_multimodal_content,
     is_mllm_model,  # noqa: F401
 )
+from .audio_limits import (
+    DEFAULT_MAX_AUDIO_UPLOAD_BYTES,
+    DEFAULT_MAX_AUDIO_UPLOAD_MB,
+    DEFAULT_MAX_TTS_INPUT_CHARS,
+    save_upload_with_limit,
+    validate_tts_input_length,
+)
 from .engine import BaseEngine, BatchedEngine, GenerationOutput, SimpleEngine
 from .endpoint_model_policies import (
     resolve_embedding_model_name,
@@ -161,6 +167,8 @@
 _default_temperature: float | None = None  # Set via --default-temperature
 _default_top_p: float | None = None  # Set via --default-top-p
 _metrics_enabled = False
+_max_audio_upload_bytes: int = DEFAULT_MAX_AUDIO_UPLOAD_BYTES
+_max_tts_input_chars: int = DEFAULT_MAX_TTS_INPUT_CHARS
 
 _FALLBACK_TEMPERATURE = 0.7
 _FALLBACK_TOP_P = 0.9
@@ -2045,11 +2053,12 @@ async def create_transcription(
             _stt_engine = STTEngine(model_name)
             _stt_engine.load()
 
-        # Save uploaded file temporarily
-        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
-            content = await file.read()
-            tmp.write(content)
-            tmp_path = tmp.name
+        # Stream uploaded file to disk under a hard size cap.
+        tmp_path = await save_upload_with_limit(
+            file,
+            max_bytes=_max_audio_upload_bytes,
+            default_suffix=".wav",
+        )
 
         try:
             result = _stt_engine.transcribe(tmp_path, language=language)
@@ -2106,6 +2115,7 @@ async def create_speech(
         from .audio.tts import TTSEngine  # Lazy import - optional feature
 
         model_name = resolve_tts_model_name(model)
+        validate_tts_input_length(input, max_chars=_max_tts_input_chars)
 
         # Load engine if needed
         if _tts_engine is None or _tts_engine.model_name != model_name:
@@ -4241,6 +4251,7 @@ def main():
     # Set global configuration
     global _api_key, _default_timeout, _rate_limiter, _metrics_enabled
     global _default_temperature, _default_top_p
+    global _max_audio_upload_bytes, _max_tts_input_chars
     _api_key = args.api_key
     _default_timeout = args.timeout
     _metrics_enabled = args.enable_metrics
@@ -4249,6 +4260,8 @@ def main():
         _default_temperature = args.default_temperature
     if args.default_top_p is not None:
         _default_top_p = args.default_top_p
+    _max_audio_upload_bytes = args.max_audio_upload_mb * 1024 * 1024
+    _max_tts_input_chars = args.max_tts_input_chars
 
     # Configure rate limiter
     if args.rate_limit > 0:
@@ -4278,6 +4291,10 @@ def main():
         logger.warning(" Remote code loading: ENABLED (--trust-remote-code)")
     else:
         logger.info(" Remote code loading: DISABLED (default)")
+    logger.info(
+        f" Audio upload limit: {args.max_audio_upload_mb} MiB, "
+        f"TTS input limit: {args.max_tts_input_chars} chars"
+    )
     logger.info("=" * 60)
 
     # Set MCP config for lifespan
@@ -4426,6 +4443,18 @@ def create_parser() -> argparse.ArgumentParser:
         default=None,
         help="Default top_p for generation when not specified in request",
     )
+    parser.add_argument(
+        "--max-audio-upload-mb",
+        type=int,
+        default=DEFAULT_MAX_AUDIO_UPLOAD_MB,
+        help="Maximum size of uploaded audio files in MiB (default: 25)",
+    )
+    parser.add_argument(
+        "--max-tts-input-chars",
+        type=int,
+        default=DEFAULT_MAX_TTS_INPUT_CHARS,
+        help="Maximum number of characters accepted by /v1/audio/speech (default: 4096)",
+    )
 
     return parser
 