-
-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[Bugfix][Frontend] Fix audio transcription for MP4, M4A, and WebM formats #35109
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
vllm-bot
merged 1 commit into
vllm-project:main
from
seanmamasde:fix/audio-transcription-mp4-m4a-webm
Mar 14, 2026
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,106 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| """Audio decoding utilities for the speech-to-text endpoints.""" | ||
|
|
||
| import io | ||
|
|
||
| import numpy as np | ||
| import torchaudio | ||
|
|
||
| from vllm.logger import init_logger | ||
| from vllm.utils.import_utils import PlaceholderModule | ||
|
|
||
| try: | ||
| import librosa | ||
| except ImportError: | ||
| librosa = PlaceholderModule("librosa") # type: ignore[assignment] | ||
|
|
||
| try: | ||
| import soundfile as sf | ||
| except ImportError: | ||
| sf = PlaceholderModule("soundfile") # type: ignore[assignment] | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
| # Public libsndfile error codes exposed via ``soundfile.LibsndfileError.code``. | ||
| # soundfile is librosa's primary backend. These codes indicate that the audio | ||
| # data itself is problematic (unrecognised container, corrupt file, or | ||
| # unsupported encoding) rather than a transient server error. | ||
| # 1 = unrecognised format, 3 = malformed file, 4 = unsupported encoding | ||
| _BAD_SF_CODES = {1, 3, 4} | ||
|
|
||
|
|
||
| def _decode_audio_bytes_torchaudio( | ||
| audio_data: bytes, | ||
| sr: int, | ||
| ) -> tuple[np.ndarray, int]: | ||
| """Decode audio bytes to mono float32 PCM via torchaudio, in-process. | ||
|
|
||
| ``torchaudio.load`` (backed by TorchCodec / FFmpeg) can decode | ||
| container formats (MP4, M4A, WebM) directly from a ``BytesIO`` | ||
| buffer without spawning a subprocess. The decoded waveform is | ||
| down-mixed to mono and resampled to *sr* Hz, matching the return | ||
| convention of ``librosa.load``. | ||
| """ | ||
| buf = io.BytesIO(audio_data) | ||
| waveform, orig_sr = torchaudio.load(buf) | ||
|
|
||
| # Down-mix to mono (average across channels). | ||
| if waveform.shape[0] > 1: | ||
| waveform = waveform.mean(dim=0, keepdim=True) | ||
|
|
||
| # Resample to the target sample rate when necessary. | ||
| if orig_sr != sr: | ||
| waveform = torchaudio.functional.resample( | ||
| waveform, orig_freq=orig_sr, new_freq=sr | ||
| ) | ||
|
|
||
| # Squeeze channel dim → 1-D float32 numpy array (same as librosa.load). | ||
| y = waveform.squeeze(0).numpy() | ||
| if y.size == 0: | ||
| raise RuntimeError( | ||
| "torchaudio produced no audio samples (file may be empty or corrupt)" | ||
| ) | ||
| return y, sr | ||
|
|
||
|
|
||
| def load_audio_bytes( | ||
| audio_data: bytes, | ||
| sr: int | float, | ||
| ) -> tuple[np.ndarray, int]: | ||
| """Load audio from raw bytes, with an in-process torchaudio fallback. | ||
|
|
||
| First tries ``librosa.load(BytesIO(...))`` which works for formats | ||
| that *soundfile* can auto-detect (WAV, FLAC, MP3, OGG, ...). If | ||
| that fails with a ``LibsndfileError`` indicating an unrecognised or | ||
| unsupported format (typically container formats like MP4/M4A/WebM), | ||
| the bytes are decoded in-process via ``torchaudio`` (backed by | ||
| TorchCodec / FFmpeg) which handles these containers natively. | ||
| """ | ||
| sr = int(sr) | ||
|
|
||
| # Fast path: librosa + soundfile (works for most formats). | ||
| try: | ||
| with io.BytesIO(audio_data) as buf: | ||
| return librosa.load(buf, sr=sr) # type: ignore[return-value] | ||
| except sf.LibsndfileError as exc: | ||
| # Only fall back for known format-detection failures. | ||
| # Re-raise anything else (e.g. corrupt but recognised format). | ||
| if exc.code not in _BAD_SF_CODES: | ||
| raise | ||
| logger.debug( | ||
| "librosa/soundfile could not decode audio from BytesIO " | ||
| "(code=%s: %s); falling back to torchaudio in-process decode", | ||
| exc.code, | ||
| exc, | ||
| ) | ||
|
|
||
| # Fallback: torchaudio in-process decode (no subprocess overhead). | ||
| try: | ||
| return _decode_audio_bytes_torchaudio(audio_data, sr) | ||
| except Exception as ta_exc: | ||
| logger.debug( | ||
| "torchaudio fallback also failed: %s", | ||
| ta_exc, | ||
| ) | ||
| raise ValueError("Invalid or unsupported audio file.") from ta_exc |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a bit worried that
torchcodecwill break audio support on GB200 + aarch64 CPU, because it only distributes x86_64 manylinux wheels (https://pypi.org/project/torchcodec/#files).I opened #37061 to revert this PR and use
pyavfor video fallback instead.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I actually investigated this a bit back:
At the time of implementation,
torchaudioseems like the best bet since it doesn't introduce extra deps and is an in-process conversion (as opposed to tempfile, subprocess w/ ffmpeg). But it seems that starting with torchaudiov2.9.0+it uses torchcodec fortorchaudio.save()andtorchaudio.load()