diff --git a/test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py b/test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py index ef0e56f0e5d..9b6fc899e48 100644 --- a/test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py +++ b/test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py @@ -8,7 +8,7 @@ import torch from parameterized import parameterized from torchaudio._backend.utils import get_save_func -from torchaudio.io._compat import _get_encoder, _get_encoder_format +from torchaudio.io._compat import _parse_save_args from torchaudio_unittest.backend.dispatcher.sox.common import get_enc_params, name_func from torchaudio_unittest.common_utils import ( @@ -24,12 +24,14 @@ ) -def _convert_audio_file(src_path, dst_path, format=None, acodec=None): - command = ["ffmpeg", "-y", "-i", src_path, "-strict", "-2"] - if format: - command += ["-sample_fmt", format] - if acodec: - command += ["-acodec", acodec] +def _convert_audio_file(src_path, dst_path, muxer=None, encoder=None, sample_fmt=None): + command = ["ffmpeg", "-hide_banner", "-y", "-i", src_path, "-strict", "-2"] + if muxer: + command += ["-f", muxer] + if encoder: + command += ["-acodec", encoder] + if sample_fmt: + command += ["-sample_fmt", sample_fmt] command += [dst_path] print(" ".join(command), file=sys.stderr) subprocess.run(command, check=True) @@ -97,11 +99,15 @@ def assert_save_consistency( data = get_wav_data(src_dtype, num_channels, normalize=False, num_frames=num_frames) save_wav(src_path, data, sample_rate) + import torchaudio + torchaudio.utils.ffmpeg_utils.set_log_level(32) # 2.1. Convert the original wav to target format with torchaudio data = load_wav(src_path, normalize=False)[0] if test_mode == "path": - self._save(tgt_path, data, sample_rate, encoding=encoding, bits_per_sample=bits_per_sample) + ext = format + self._save(tgt_path, data, sample_rate, format=format, encoding=encoding, bits_per_sample=bits_per_sample) elif test_mode == "fileobj": + ext = None with open(tgt_path, "bw") as file_: self._save( file_, @@ -113,6 +119,7 @@ def assert_save_consistency( ) elif test_mode == "bytesio": file_ = io.BytesIO() + ext = None self._save( file_, data, @@ -127,16 +134,15 @@ def assert_save_consistency( else: raise ValueError(f"Unexpected test mode: {test_mode}") # 2.2. Convert the target format to wav with ffmpeg - _convert_audio_file(tgt_path, tst_path, acodec="pcm_f32le") + _convert_audio_file(tgt_path, tst_path, encoder="pcm_f32le") # 2.3. Load with SciPy found = load_wav(tst_path, normalize=False)[0] # 3.1. Convert the original wav to target format with ffmpeg - acodec = _get_encoder(data.dtype, format, encoding, bits_per_sample) - sample_fmt = _get_encoder_format(format, bits_per_sample) - _convert_audio_file(src_path, sox_path, acodec=acodec, format=sample_fmt) + muxer, encoder, sample_fmt = _parse_save_args(ext, format, encoding, bits_per_sample) + _convert_audio_file(src_path, sox_path, muxer=muxer, encoder=encoder, sample_fmt=sample_fmt) # 3.2. Convert the target format to wav with ffmpeg - _convert_audio_file(sox_path, ref_path, acodec="pcm_f32le") + _convert_audio_file(sox_path, ref_path, encoder="pcm_f32le") # 3.3. Load with SciPy expected = load_wav(ref_path, normalize=False)[0] diff --git a/torchaudio/io/_compat.py b/torchaudio/io/_compat.py index 7b122cbcc8c..af6ee4c522d 100644 --- a/torchaudio/io/_compat.py +++ b/torchaudio/io/_compat.py @@ -102,7 +102,8 @@ def load_audio_fileobj( format: Optional[str] = None, buffer_size: int = 4096, ) -> Tuple[torch.Tensor, int]: - s = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, format, None, buffer_size) + file_format = "ogg" if format == "vorbis" else format + s = torchaudio.lib._torchaudio_ffmpeg.StreamReaderFileObj(src, file_format, None, buffer_size) sample_rate = int(s.get_src_stream_info(s.find_best_audio_stream()).sample_rate) filter = _get_load_filter(frame_offset, num_frames, convert) waveform = _load_audio_fileobj(s, filter, channels_first) @@ -131,7 +132,7 @@ def _native_endianness() -> str: return "be" -def _get_encoder_for_wav(dtype: torch.dtype, encoding: str, bits_per_sample: int) -> str: +def _get_encoder_for_wav(encoding: str, bits_per_sample: int) -> str: if bits_per_sample not in {None, 8, 16, 24, 32, 64}: raise ValueError(f"Invalid bits_per_sample {bits_per_sample} for WAV encoding.") endianness = _native_endianness() @@ -148,49 +149,80 @@ def _get_encoder_for_wav(dtype: torch.dtype, encoding: str, bits_per_sample: int if bits_per_sample == 8: raise ValueError("For WAV signed PCM, 8-bit encoding is not supported.") return f"pcm_s{bits_per_sample}{endianness}" - elif encoding == "PCM_U": + if encoding == "PCM_U": if bits_per_sample in (None, 8): return "pcm_u8" raise ValueError("For WAV unsigned PCM, only 8-bit encoding is supported.") - elif encoding == "PCM_F": + if encoding == "PCM_F": if not bits_per_sample: bits_per_sample = 32 if bits_per_sample in (32, 64): return f"pcm_f{bits_per_sample}{endianness}" raise ValueError("For WAV float PCM, only 32- and 64-bit encodings are supported.") - elif encoding == "ULAW": + if encoding == "ULAW": if bits_per_sample in (None, 8): return "pcm_mulaw" raise ValueError("For WAV PCM mu-law, only 8-bit encoding is supported.") - elif encoding == "ALAW": + if encoding == "ALAW": if bits_per_sample in (None, 8): return "pcm_alaw" raise ValueError("For WAV PCM A-law, only 8-bit encoding is supported.") raise ValueError(f"WAV encoding {encoding} is not supported.") -def _get_encoder(dtype: torch.dtype, format: str, encoding: str, bits_per_sample: int) -> str: - if format == "wav": - return _get_encoder_for_wav(dtype, encoding, bits_per_sample) - if format == "flac": - return "flac" - if format in ("ogg", "vorbis"): - if encoding or bits_per_sample: - raise ValueError("ogg/vorbis does not support encoding/bits_per_sample.") - return "vorbis" - return format +def _get_flac_sample_fmt(bps): + if bps is None or bps == 16: + return "s16" + if bps == 24: + return "s32" + raise ValueError(f"FLAC only supports bits_per_sample values of 16 and 24 ({bps} specified).") -def _get_encoder_format(format: str, bits_per_sample: Optional[int]) -> str: - if format == "flac": - if not bits_per_sample: - return "s16" - if bits_per_sample == 24: - return "s32" - if bits_per_sample == 16: - return "s16" - raise ValueError(f"FLAC only supports bits_per_sample values of 16 and 24 ({bits_per_sample} specified).") - return None +def _parse_save_args( + ext: Optional[str], + format: Optional[str], + encoding: Optional[str], + bps: Optional[int], +): + # torchaudio's save function accepts the followings, which do not 1to1 map + # to FFmpeg. + # + # - format: audio format + # - bits_per_sample: encoder sample format + # - encoding: such as PCM_U8. + # + # In FFmpeg, format is specified with the following three (and more) + # + # - muxer: could be audio format or container format. + # the one we passed to the constructor of StreamWriter + # - encoder: the audio encoder used to encode audio + # - encoder sample format: the format used by encoder to encode audio. + # + # If encoder sample format is different from source sample format, StreamWriter + # will insert a filter automatically. + # + if format == "wav" or (format is None and ext == "wav"): + # wav is special because it supports different encoding through encoders + # each encoder only supports one encoder format + muxer = "wav" + encoder = _get_encoder_for_wav(encoding, bps) + sample_fmt = None + elif format == "vorbis" or (format is None and ext == "vorbis"): + # FFpmeg does not recognize vorbis extension, while libsox used to do. + # For the sake of bakward compatibility, (and the simplicity), + # we support the case where users want to do save("foo.vorbis") + muxer = "ogg" + encoder = "vorbis" + sample_fmt = None + else: + muxer = format + encoder = None + sample_fmt = None + if format == "flac" or format is None and ext == "flac": + sample_fmt = _get_flac_sample_fmt(bps) + if format == "ogg" or format is None and ext == "ogg": + sample_fmt = _get_flac_sample_fmt(bps) + return muxer, encoder, sample_fmt # NOTE: in contrast to load_audio* and info_audio*, this function is NOT compatible with TorchScript. @@ -204,25 +236,27 @@ def save_audio( bits_per_sample: Optional[int] = None, buffer_size: int = 4096, ) -> None: + ext = None if hasattr(uri, "write"): if format is None: raise RuntimeError("'format' is required when saving to file object.") else: uri = os.path.normpath(uri) - s = StreamWriter(uri, format=format, buffer_size=buffer_size) - if format is None: - tokens = str(uri).split(".") - if len(tokens) > 1: - format = tokens[-1].lower() + if tokens := str(uri).split(".")[1:]: + ext = tokens[-1].lower() + + muxer, encoder, enc_fmt = _parse_save_args(ext, format, encoding, bits_per_sample) if channels_first: src = src.T + + s = StreamWriter(uri, format=muxer, buffer_size=buffer_size) s.add_audio_stream( sample_rate, num_channels=src.size(-1), format=_get_sample_format(src.dtype), - encoder=_get_encoder(src.dtype, format, encoding, bits_per_sample), - encoder_format=_get_encoder_format(format, bits_per_sample), + encoder=encoder, + encoder_format=enc_fmt, ) with s.open(): s.write_audio_chunk(0, src)