Skip to content

Commit

Permalink
Refactor _compat.save function
Browse files Browse the repository at this point in the history
When dealing with vorbis format, FFmpeg expects "ogg" container/extension
with "vorbis" encoder. It does not recognize "vorbis" container/extension.

libsox-based torchaudio I/O used to handle vorbis extension.

This commit refactors the internal handling of the save arguments and adds support
for "vorbis" as a file extension when the FFmpeg backend is used.

This also fixes the case of mp3

#3385
  • Loading branch information
mthrok committed May 30, 2023
1 parent af932cc commit 4a31b59
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 46 deletions.
32 changes: 19 additions & 13 deletions test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
from parameterized import parameterized
from torchaudio._backend.utils import get_save_func
from torchaudio.io._compat import _get_encoder, _get_encoder_format
from torchaudio.io._compat import _parse_save_args

from torchaudio_unittest.backend.dispatcher.sox.common import get_enc_params, name_func
from torchaudio_unittest.common_utils import (
Expand All @@ -24,12 +24,14 @@
)


def _convert_audio_file(src_path, dst_path, format=None, acodec=None):
command = ["ffmpeg", "-y", "-i", src_path, "-strict", "-2"]
if format:
command += ["-sample_fmt", format]
if acodec:
command += ["-acodec", acodec]
def _convert_audio_file(src_path, dst_path, muxer=None, encoder=None, sample_fmt=None):
    """Convert ``src_path`` into ``dst_path`` by invoking the ``ffmpeg`` CLI.

    Args:
        src_path: Path of the input file, passed to ``ffmpeg -i``.
        dst_path: Path of the output file (overwritten because of ``-y``).
        muxer: When truthy, forwarded as ``-f <muxer>``.
        encoder: When truthy, forwarded as ``-acodec <encoder>``.
        sample_fmt: When truthy, forwarded as ``-sample_fmt <sample_fmt>``.

    Raises:
        subprocess.CalledProcessError: If ``ffmpeg`` exits with a non-zero
            status (``check=True``).
    """
    cmd = ["ffmpeg", "-hide_banner", "-y", "-i", src_path, "-strict", "-2"]
    # Optional switches are emitted in a fixed order, and only when a value is given.
    optional_switches = (
        ("-f", muxer),
        ("-acodec", encoder),
        ("-sample_fmt", sample_fmt),
    )
    for flag, value in optional_switches:
        if value:
            cmd.extend((flag, value))
    cmd.append(dst_path)
    # Echo the exact command line to stderr so a failing conversion can be reproduced.
    print(" ".join(cmd), file=sys.stderr)
    subprocess.run(cmd, check=True)
Expand Down Expand Up @@ -97,11 +99,15 @@ def assert_save_consistency(
data = get_wav_data(src_dtype, num_channels, normalize=False, num_frames=num_frames)
save_wav(src_path, data, sample_rate)

import torchaudio
torchaudio.utils.ffmpeg_utils.set_log_level(32)
# 2.1. Convert the original wav to target format with torchaudio
data = load_wav(src_path, normalize=False)[0]
if test_mode == "path":
self._save(tgt_path, data, sample_rate, encoding=encoding, bits_per_sample=bits_per_sample)
ext = format
self._save(tgt_path, data, sample_rate, format=format, encoding=encoding, bits_per_sample=bits_per_sample)
elif test_mode == "fileobj":
ext = None
with open(tgt_path, "bw") as file_:
self._save(
file_,
Expand All @@ -113,6 +119,7 @@ def assert_save_consistency(
)
elif test_mode == "bytesio":
file_ = io.BytesIO()
ext = None
self._save(
file_,
data,
Expand All @@ -127,16 +134,15 @@ def assert_save_consistency(
else:
raise ValueError(f"Unexpected test mode: {test_mode}")
# 2.2. Convert the target format to wav with ffmpeg
_convert_audio_file(tgt_path, tst_path, acodec="pcm_f32le")
_convert_audio_file(tgt_path, tst_path, encoder="pcm_f32le")
# 2.3. Load with SciPy
found = load_wav(tst_path, normalize=False)[0]

# 3.1. Convert the original wav to target format with ffmpeg
acodec = _get_encoder(data.dtype, format, encoding, bits_per_sample)
sample_fmt = _get_encoder_format(format, bits_per_sample)
_convert_audio_file(src_path, sox_path, acodec=acodec, format=sample_fmt)
muxer, encoder, sample_fmt = _parse_save_args(ext, format, encoding, bits_per_sample)
_convert_audio_file(src_path, sox_path, muxer=muxer, encoder=encoder, sample_fmt=sample_fmt)
# 3.2. Convert the target format to wav with ffmpeg
_convert_audio_file(sox_path, ref_path, acodec="pcm_f32le")
_convert_audio_file(sox_path, ref_path, encoder="pcm_f32le")
# 3.3. Load with SciPy
expected = load_wav(ref_path, normalize=False)[0]

Expand Down
100 changes: 67 additions & 33 deletions torchaudio/io/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def load_audio_fileobj(
format: Optional[str] = None,
buffer_size: int = 4096,
) -> Tuple[torch.Tensor, int]:
s = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, buffer_size)
file_format = "ogg" if format == "vorbis" else format
s = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, file_format, None, buffer_size)
sample_rate = int(s.get_src_stream_info(s.find_best_audio_stream()).sample_rate)
filter = _get_load_filter(frame_offset, num_frames, convert)
waveform = _load_audio_fileobj(s, filter, channels_first)
Expand Down Expand Up @@ -131,7 +132,7 @@ def _native_endianness() -> str:
return "be"


def _get_encoder_for_wav(dtype: torch.dtype, encoding: str, bits_per_sample: int) -> str:
def _get_encoder_for_wav(encoding: str, bits_per_sample: int) -> str:
if bits_per_sample not in {None, 8, 16, 24, 32, 64}:
raise ValueError(f"Invalid bits_per_sample {bits_per_sample} for WAV encoding.")
endianness = _native_endianness()
Expand All @@ -148,49 +149,80 @@ def _get_encoder_for_wav(dtype: torch.dtype, encoding: str, bits_per_sample: int
if bits_per_sample == 8:
raise ValueError("For WAV signed PCM, 8-bit encoding is not supported.")
return f"pcm_s{bits_per_sample}{endianness}"
elif encoding == "PCM_U":
if encoding == "PCM_U":
if bits_per_sample in (None, 8):
return "pcm_u8"
raise ValueError("For WAV unsigned PCM, only 8-bit encoding is supported.")
elif encoding == "PCM_F":
if encoding == "PCM_F":
if not bits_per_sample:
bits_per_sample = 32
if bits_per_sample in (32, 64):
return f"pcm_f{bits_per_sample}{endianness}"
raise ValueError("For WAV float PCM, only 32- and 64-bit encodings are supported.")
elif encoding == "ULAW":
if encoding == "ULAW":
if bits_per_sample in (None, 8):
return "pcm_mulaw"
raise ValueError("For WAV PCM mu-law, only 8-bit encoding is supported.")
elif encoding == "ALAW":
if encoding == "ALAW":
if bits_per_sample in (None, 8):
return "pcm_alaw"
raise ValueError("For WAV PCM A-law, only 8-bit encoding is supported.")
raise ValueError(f"WAV encoding {encoding} is not supported.")


def _get_encoder(dtype: torch.dtype, format: str, encoding: str, bits_per_sample: int) -> str:
if format == "wav":
return _get_encoder_for_wav(dtype, encoding, bits_per_sample)
if format == "flac":
return "flac"
if format in ("ogg", "vorbis"):
if encoding or bits_per_sample:
raise ValueError("ogg/vorbis does not support encoding/bits_per_sample.")
return "vorbis"
return format
def _get_flac_sample_fmt(bps):
if bps is None or bps == 16:
return "s16"
if bps == 24:
return "s32"
raise ValueError(f"FLAC only supports bits_per_sample values of 16 and 24 ({bps} specified).")


def _get_encoder_format(format: str, bits_per_sample: Optional[int]) -> str:
if format == "flac":
if not bits_per_sample:
return "s16"
if bits_per_sample == 24:
return "s32"
if bits_per_sample == 16:
return "s16"
raise ValueError(f"FLAC only supports bits_per_sample values of 16 and 24 ({bits_per_sample} specified).")
return None
def _parse_save_args(
ext: Optional[str],
format: Optional[str],
encoding: Optional[str],
bps: Optional[int],
):
# torchaudio's save function accepts the followings, which do not 1to1 map
# to FFmpeg.
#
# - format: audio format
# - bits_per_sample: encoder sample format
# - encoding: such as PCM_U8.
#
# In FFmpeg, format is specified with the following three (and more)
#
# - muxer: could be audio format or container format.
# the one we passed to the constructor of StreamWriter
# - encoder: the audio encoder used to encode audio
# - encoder sample format: the format used by encoder to encode audio.
#
# If encoder sample format is different from source sample format, StreamWriter
# will insert a filter automatically.
#
if format == "wav" or (format is None and ext == "wav"):
# wav is special because it supports different encoding through encoders
# each encoder only supports one encoder format
muxer = "wav"
encoder = _get_encoder_for_wav(encoding, bps)
sample_fmt = None
elif format == "vorbis" or (format is None and ext == "vorbis"):
# FFpmeg does not recognize vorbis extension, while libsox used to do.
# For the sake of bakward compatibility, (and the simplicity),
# we support the case where users want to do save("foo.vorbis")
muxer = "ogg"
encoder = "vorbis"
sample_fmt = None
else:
muxer = format
encoder = None
sample_fmt = None
if format == "flac" or format is None and ext == "flac":
sample_fmt = _get_flac_sample_fmt(bps)
if format == "ogg" or format is None and ext == "ogg":
sample_fmt = _get_flac_sample_fmt(bps)
return muxer, encoder, sample_fmt


# NOTE: in contrast to load_audio* and info_audio*, this function is NOT compatible with TorchScript.
Expand All @@ -204,25 +236,27 @@ def save_audio(
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
) -> None:
ext = None
if hasattr(uri, "write"):
if format is None:
raise RuntimeError("'format' is required when saving to file object.")
else:
uri = os.path.normpath(uri)
s = StreamWriter(uri, format=format, buffer_size=buffer_size)
if format is None:
tokens = str(uri).split(".")
if len(tokens) > 1:
format = tokens[-1].lower()
if tokens := str(uri).split(".")[1:]:
ext = tokens[-1].lower()

muxer, encoder, enc_fmt = _parse_save_args(ext, format, encoding, bits_per_sample)

if channels_first:
src = src.T

s = StreamWriter(uri, format=muxer, buffer_size=buffer_size)
s.add_audio_stream(
sample_rate,
num_channels=src.size(-1),
format=_get_sample_format(src.dtype),
encoder=_get_encoder(src.dtype, format, encoding, bits_per_sample),
encoder_format=_get_encoder_format(format, bits_per_sample),
encoder=encoder,
encoder_format=enc_fmt,
)
with s.open():
s.write_audio_chunk(0, src)

0 comments on commit 4a31b59

Please sign in to comment.