Skip to content

Commit

Permalink
Handle vorbis extension in ffmpeg backend
Browse files Browse the repository at this point in the history
When dealing with the vorbis format, FFmpeg expects the "ogg" container/extension
together with the "vorbis" encoder. It does not recognize "vorbis" as a container/extension.

The libsox-based torchaudio I/O used to handle the vorbis extension.
This commit adds support for "vorbis" as an extension in those cases with the FFmpeg backend.
  • Loading branch information
mthrok committed May 29, 2023
1 parent af932cc commit f161094
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions torchaudio/io/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def load_audio_fileobj(
format: Optional[str] = None,
buffer_size: int = 4096,
) -> Tuple[torch.Tensor, int]:
s = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, buffer_size)
file_format = "ogg" if format == "vorbis" else format
s = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, file_format, None, buffer_size)
sample_rate = int(s.get_src_stream_info(s.find_best_audio_stream()).sample_rate)
filter = _get_load_filter(frame_offset, num_frames, convert)
waveform = _load_audio_fileobj(s, filter, channels_first)
Expand Down Expand Up @@ -204,25 +205,32 @@ def save_audio(
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
) -> None:
ext = None
if hasattr(uri, "write"):
if format is None:
raise RuntimeError("'format' is required when saving to file object.")
else:
uri = os.path.normpath(uri)
s = StreamWriter(uri, format=format, buffer_size=buffer_size)
if format is None:
tokens = str(uri).split(".")
if len(tokens) > 1:
format = tokens[-1].lower()
if tokens := str(uri).split(".")[1:]:
ext = tokens[-1].lower()

# FFmpeg does not recognize the vorbis extension, whereas libsox used to.
# For the sake of backward compatibility (and simplicity),
# we support the case where users want to do save("foo.vorbis")
if (ext == "vorbis" and format is None) or hasattr(uri, "write") and format == "vorbis":
file_format = "ogg"
else:
file_format = format
s = StreamWriter(uri, format=file_format, buffer_size=buffer_size)

if channels_first:
src = src.T
s.add_audio_stream(
sample_rate,
num_channels=src.size(-1),
format=_get_sample_format(src.dtype),
encoder=_get_encoder(src.dtype, format, encoding, bits_per_sample),
encoder_format=_get_encoder_format(format, bits_per_sample),
encoder=_get_encoder(src.dtype, ext, encoding, bits_per_sample),
encoder_format=_get_encoder_format(ext, bits_per_sample),
)
with s.open():
s.write_audio_chunk(0, src)

0 comments on commit f161094

Please sign in to comment.