From d4fee19bcc92f5602d281e58b5087bb1d1605e25 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Wed, 5 Apr 2023 14:39:40 -0400 Subject: [PATCH] Fix file-like object support in FFmpeg dispatcher In dispatcher mode, FFmpeg backend does not handle file-like object, and C++ implementation raises an issue. This commit fixes it by normalizing file-like object to string. --- .../backend/dispatcher/ffmpeg/info_test.py | 20 +++++++++++++++++++ .../backend/dispatcher/ffmpeg/load_test.py | 16 +++++++++++++++ .../backend/dispatcher/ffmpeg/save_test.py | 12 +++++++++++ torchaudio/_backend/utils.py | 6 +++--- 4 files changed, 51 insertions(+), 3 deletions(-) diff --git a/test/torchaudio_unittest/backend/dispatcher/ffmpeg/info_test.py b/test/torchaudio_unittest/backend/dispatcher/ffmpeg/info_test.py index 3e75e8e160b..c233178a278 100644 --- a/test/torchaudio_unittest/backend/dispatcher/ffmpeg/info_test.py +++ b/test/torchaudio_unittest/backend/dispatcher/ffmpeg/info_test.py @@ -1,6 +1,7 @@ import io import itertools import os +import pathlib import tarfile from contextlib import contextmanager from functools import partial @@ -35,6 +36,25 @@ class TestInfo(TempDirMixin, PytorchTestCase): _info = partial(get_info_func(), backend="ffmpeg") + def test_pathlike(self): + """FFmpeg dispatcher can query audio data from pathlike object""" + sample_rate = 16000 + dtype = "float32" + num_channels = 2 + duration = 1 + + path = self.get_temp_path("data.wav") + data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) + save_wav(path, data, sample_rate) + + info = self._info(pathlib.Path(path)) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == sox_utils.get_bit_depth(dtype) + assert info.encoding == get_encoding("wav", dtype) + + @parameterized.expand( list( itertools.product( diff --git a/test/torchaudio_unittest/backend/dispatcher/ffmpeg/load_test.py b/test/torchaudio_unittest/backend/dispatcher/ffmpeg/load_test.py index 68eae752c92..667be0276bd 100644 --- a/test/torchaudio_unittest/backend/dispatcher/ffmpeg/load_test.py +++ b/test/torchaudio_unittest/backend/dispatcher/ffmpeg/load_test.py @@ -1,5 +1,6 @@ import io import itertools +import pathlib import tarfile from functools import partial @@ -125,6 +126,21 @@ def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration): class TestLoad(LoadTestBase): """Test the correctness of `self._load` for various formats""" + def test_pathlike(self): + """FFmpeg dispatcher can load waveform from pathlike object""" + sample_rate = 16000 + dtype = "float32" + num_channels = 2 + duration = 1 + + path = self.get_temp_path("data.wav") + data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) + save_wav(path, data, sample_rate) + + waveform, sr = self._load(pathlib.Path(path)) + self.assertEqual(sr, sample_rate) + self.assertEqual(waveform, data) + @parameterized.expand( list( itertools.product( diff --git a/test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py b/test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py index 23a74eeb517..419d7515329 100644 --- a/test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py +++ b/test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py @@ -1,5 +1,6 @@ import io import os +import pathlib import subprocess import sys from functools import partial @@ -146,6 +147,17 @@ def assert_save_consistency( @skipIfNoExec("ffmpeg") @skipIfNoFFmpeg class SaveTest(SaveTestBase): + def test_pathlike(self): + """FFmpeg dispatcher can save audio data to pathlike object""" + sample_rate = 16000 + dtype = "float32" + num_channels = 2 + duration = 1 + + path = self.get_temp_path("data.wav") + data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) + self._save(pathlib.Path(path), data, sample_rate) + @nested_params( ["path", "fileobj", "bytesio"], [ diff --git a/torchaudio/_backend/utils.py b/torchaudio/_backend/utils.py index 950f0ee9ba9..f7901300569 100644 --- a/torchaudio/_backend/utils.py +++ b/torchaudio/_backend/utils.py @@ -82,7 +82,7 @@ def info(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str], buffer_s if hasattr(uri, "read"): metadata = info_audio_fileobj(uri, format, buffer_size=buffer_size) else: - metadata = info_audio(uri, format) + metadata = info_audio(os.path.normpath(uri), format) metadata.bits_per_sample = _get_bits_per_sample(metadata.encoding, metadata.bits_per_sample) metadata.encoding = _map_encoding(metadata.encoding) return metadata @@ -108,7 +108,7 @@ def load( buffer_size, ) else: - return load_audio(uri, frame_offset, num_frames, normalize, channels_first, format) + return load_audio(os.path.normpath(uri), frame_offset, num_frames, normalize, channels_first, format) @staticmethod def save( @@ -122,7 +122,7 @@ def save( buffer_size: int = 4096, ) -> None: save_audio( - uri, + os.path.normpath(uri), src, sample_rate, channels_first,