Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,7 @@ def _read_requirements(filename: str) -> list[str]:
"soundfile",
"mistral_common[audio]",
"av",
"torchcodec",
Copy link
Member

@Isotr0py Isotr0py Mar 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit worried that torchcodec will break audio support on GB200 + aarch64 CPU, because it only distributes x86_64 manylinux wheels (https://pypi.org/project/torchcodec/#files).

I opened #37061 to revert this PR and use pyav for video fallback instead.

Copy link
Contributor Author

@seanmamasde seanmamasde Mar 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually investigated this a bit back:

| lib           | in-process? | mp4/m4a/webm using bytesio                            | new dep |
| ------------- | ----------- | ----------------------------------------------------- | ------- |
| ffmpeg-python | no          | using pipe, but still subprocess                      | yes     |
| pydub         | no          | using pipe, but still subprocess                      | yes     |
| soundfile     | yes         | libsndfile doesn't support mp4/m4a/webm               | no      |
| PyAV (av)     | yes         | av.open(BytesIO(...)) should work                     | yes     |
| torchaudio    | yes         | torchaudio.load(BytesIO(...), format=...) should work | no      |

At the time of implementation, torchaudio seemed like the best bet, since it doesn't introduce extra deps and performs the conversion in-process (as opposed to using a tempfile or a subprocess with ffmpeg). However, it seems that from torchaudio v2.9.0 onward, torchaudio.save() and torchaudio.load() use torchcodec under the hood.

], # Required for audio processing
"video": [], # Kept for backwards compatibility
"flashinfer": [], # Kept for backwards compatibility
Expand Down
27 changes: 7 additions & 20 deletions vllm/entrypoints/openai/speech_to_text/speech_to_text.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import io
import math
import time
import zlib
Expand All @@ -11,7 +10,6 @@

import numpy as np
from fastapi import Request
from soundfile import LibsndfileError
from transformers import PreTrainedTokenizerBase

import vllm.envs as envs
Expand All @@ -37,6 +35,7 @@
TranslationSegment,
TranslationStreamResponse,
)
from vllm.entrypoints.openai.speech_to_text.utils import load_audio_bytes
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs import EncoderDecoderInputs, ProcessorInputs
Expand All @@ -56,14 +55,6 @@
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]

# Public libsndfile error codes exposed via `soundfile.LibsndfileError.code`, soundfile
# being librosa's main backend. Used to validate if an audio loading error is due to a
# server error vs a client error (invalid audio file).
# 1 = unrecognised format (file is not a supported audio container)
# 3 = malformed file (corrupt or structurally invalid audio)
# 4 = unsupported encoding (codec not supported by this libsndfile build)
_BAD_SF_CODES = {1, 3, 4}

SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
SpeechToTextResponseVerbose: TypeAlias = (
TranscriptionResponseVerbose | TranslationResponseVerbose
Expand Down Expand Up @@ -202,16 +193,12 @@ async def _preprocess_speech_to_text(
value=len(audio_data) / 1024**2,
)

with io.BytesIO(audio_data) as bytes_:
try:
# NOTE resample to model SR here for efficiency. This is also a
# pre-requisite for chunking, as it assumes Whisper SR.
y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
except LibsndfileError as exc:
# Distinguish client errors (invalid audio) from server errors
if exc.code in _BAD_SF_CODES:
raise ValueError("Invalid or unsupported audio file.") from exc
raise
# Decode audio bytes. For container formats (MP4, M4A, WebM) that
# soundfile cannot detect from a BytesIO stream, load_audio_bytes
# transparently falls back to an in-process torchaudio decode.
# NOTE resample to model SR here for efficiency. This is also a
# pre-requisite for chunking, as it assumes Whisper SR.
y, sr = load_audio_bytes(audio_data, sr=self.asr_config.sample_rate)

duration = librosa.get_duration(y=y, sr=sr)
do_split_audio = (
Expand Down
106 changes: 106 additions & 0 deletions vllm/entrypoints/openai/speech_to_text/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Audio decoding utilities for the speech-to-text endpoints."""

import io

import numpy as np
import torchaudio

from vllm.logger import init_logger
from vllm.utils.import_utils import PlaceholderModule

try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]

try:
import soundfile as sf
except ImportError:
sf = PlaceholderModule("soundfile") # type: ignore[assignment]

logger = init_logger(__name__)

# Public libsndfile error codes exposed via ``soundfile.LibsndfileError.code``
# (soundfile is librosa's primary backend). Each code listed here means the
# audio data itself is the problem rather than a transient server error:
#   1 = unrecognised format
#   3 = malformed file
#   4 = unsupported encoding
# ``load_audio_bytes`` retries with torchaudio only when the error code is in
# this set; any other ``LibsndfileError`` is re-raised unchanged.
_BAD_SF_CODES = {1, 3, 4}


def _decode_audio_bytes_torchaudio(
    audio_data: bytes,
    sr: int,
) -> tuple[np.ndarray, int]:
    """Decode raw audio bytes in-process with torchaudio.

    The bytes are wrapped in a ``BytesIO`` and handed to
    ``torchaudio.load``, so no subprocess or temporary file is involved.
    The result is shaped to follow the ``librosa.load`` return
    convention: a 1-D mono signal plus its sample rate.

    Args:
        audio_data: Encoded audio payload (any container torchaudio's
            backend can read).
        sr: Target sample rate in Hz.

    Returns:
        ``(samples, sr)`` where ``samples`` is a 1-D numpy array and
        ``sr`` is the requested sample rate.

    Raises:
        RuntimeError: If decoding yields zero samples.
    """
    samples, native_rate = torchaudio.load(io.BytesIO(audio_data))

    # Multi-channel input is averaged into one channel; keepdim keeps the
    # (channels, frames) layout that the resampler and squeeze below expect.
    num_channels = samples.shape[0]
    if num_channels > 1:
        samples = samples.mean(dim=0, keepdim=True)

    # Skip resampling entirely when the file is already at the target rate.
    needs_resample = native_rate != sr
    if needs_resample:
        samples = torchaudio.functional.resample(
            samples,
            orig_freq=native_rate,
            new_freq=sr,
        )

    # Drop the channel axis to obtain a flat numpy signal.
    pcm = samples.squeeze(0).numpy()
    if pcm.size:
        return pcm, sr

    raise RuntimeError(
        "torchaudio produced no audio samples (file may be empty or corrupt)"
    )


def load_audio_bytes(
audio_data: bytes,
sr: int | float,
) -> tuple[np.ndarray, int]:
"""Load audio from raw bytes, with an in-process torchaudio fallback.

First tries ``librosa.load(BytesIO(...))`` which works for formats
that *soundfile* can auto-detect (WAV, FLAC, MP3, OGG, ...). If
that fails with a ``LibsndfileError`` indicating an unrecognised or
unsupported format (typically container formats like MP4/M4A/WebM),
the bytes are decoded in-process via ``torchaudio`` (backed by
TorchCodec / FFmpeg) which handles these containers natively.
"""
sr = int(sr)

# Fast path: librosa + soundfile (works for most formats).
try:
with io.BytesIO(audio_data) as buf:
return librosa.load(buf, sr=sr) # type: ignore[return-value]
except sf.LibsndfileError as exc:
# Only fall back for known format-detection failures.
# Re-raise anything else (e.g. corrupt but recognised format).
if exc.code not in _BAD_SF_CODES:
raise
logger.debug(
"librosa/soundfile could not decode audio from BytesIO "
"(code=%s: %s); falling back to torchaudio in-process decode",
exc.code,
exc,
)

# Fallback: torchaudio in-process decode (no subprocess overhead).
try:
return _decode_audio_bytes_torchaudio(audio_data, sr)
except Exception as ta_exc:
logger.debug(
"torchaudio fallback also failed: %s",
ta_exc,
)
raise ValueError("Invalid or unsupported audio file.") from ta_exc
Loading