diff --git a/test/torchaudio_unittest/backend/dispatcher/ffmpeg/info_test.py b/test/torchaudio_unittest/backend/dispatcher/ffmpeg/info_test.py index 3e75e8e160..58a085636b 100644 --- a/test/torchaudio_unittest/backend/dispatcher/ffmpeg/info_test.py +++ b/test/torchaudio_unittest/backend/dispatcher/ffmpeg/info_test.py @@ -1,6 +1,7 @@ import io import itertools import os +import pathlib import tarfile from contextlib import contextmanager from functools import partial @@ -35,6 +36,24 @@ class TestInfo(TempDirMixin, PytorchTestCase): _info = partial(get_info_func(), backend="ffmpeg") + def test_pathlike(self): + """FFmpeg dispatcher can query audio data from pathlike object""" + sample_rate = 16000 + dtype = "float32" + num_channels = 2 + duration = 1 + + path = self.get_temp_path("data.wav") + data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) + save_wav(path, data, sample_rate) + + info = self._info(pathlib.Path(path)) + assert info.sample_rate == sample_rate + assert info.num_frames == sample_rate * duration + assert info.num_channels == num_channels + assert info.bits_per_sample == sox_utils.get_bit_depth(dtype) + assert info.encoding == get_encoding("wav", dtype) + @parameterized.expand( list( itertools.product( diff --git a/test/torchaudio_unittest/backend/dispatcher/ffmpeg/load_test.py b/test/torchaudio_unittest/backend/dispatcher/ffmpeg/load_test.py index 68eae752c9..667be0276b 100644 --- a/test/torchaudio_unittest/backend/dispatcher/ffmpeg/load_test.py +++ b/test/torchaudio_unittest/backend/dispatcher/ffmpeg/load_test.py @@ -1,5 +1,6 @@ import io import itertools +import pathlib import tarfile from functools import partial @@ -125,6 +126,21 @@ def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration): class TestLoad(LoadTestBase): """Test the correctness of `self._load` for various formats""" + def test_pathlike(self): + """FFmpeg dispatcher can load waveform from pathlike object""" + sample_rate = 16000 + dtype = "float32" + num_channels = 2 + duration = 1 + + path = self.get_temp_path("data.wav") + data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) + save_wav(path, data, sample_rate) + + waveform, sr = self._load(pathlib.Path(path)) + self.assertEqual(sr, sample_rate) + self.assertEqual(waveform, data) + @parameterized.expand( list( itertools.product( diff --git a/test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py b/test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py index 23a74eeb51..419d751532 100644 --- a/test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py +++ b/test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py @@ -1,5 +1,6 @@ import io import os +import pathlib import subprocess import sys from functools import partial @@ -146,6 +147,17 @@ def assert_save_consistency( @skipIfNoExec("ffmpeg") @skipIfNoFFmpeg class SaveTest(SaveTestBase): + def test_pathlike(self): + """FFmpeg dispatcher can save audio data to pathlike object""" + sample_rate = 16000 + dtype = "float32" + num_channels = 2 + duration = 1 + + path = self.get_temp_path("data.wav") + data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate) + self._save(pathlib.Path(path), data, sample_rate) + @nested_params( ["path", "fileobj", "bytesio"], [ diff --git a/torchaudio/_backend/utils.py b/torchaudio/_backend/utils.py index 950f0ee9ba..70361b8aa4 100644 --- a/torchaudio/_backend/utils.py +++ b/torchaudio/_backend/utils.py @@ -82,7 +82,7 @@ def info(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str], buffer_s if hasattr(uri, "read"): metadata = info_audio_fileobj(uri, format, buffer_size=buffer_size) else: - metadata = info_audio(uri, format) + metadata = info_audio(os.path.normpath(uri), format) metadata.bits_per_sample = _get_bits_per_sample(metadata.encoding, metadata.bits_per_sample) metadata.encoding = _map_encoding(metadata.encoding) return metadata @@ -108,7 +108,7 @@ def load( buffer_size, ) else: - return load_audio(uri, frame_offset, num_frames, normalize, channels_first, format) + return load_audio(os.path.normpath(uri), frame_offset, num_frames, normalize, channels_first, format) @staticmethod def save( diff --git a/torchaudio/io/_compat.py b/torchaudio/io/_compat.py index cc44ed26e0..fce454142a 100644 --- a/torchaudio/io/_compat.py +++ b/torchaudio/io/_compat.py @@ -216,8 +216,11 @@ def save_audio( bits_per_sample: Optional[int] = None, buffer_size: int = 4096, ) -> None: - if hasattr(uri, "write") and format is None: - raise RuntimeError("'format' is required when saving to file object.") + if hasattr(uri, "write"): + if format is None: + raise RuntimeError("'format' is required when saving to file object.") + else: + uri = os.path.normpath(uri) s = StreamWriter(uri, format=format, buffer_size=buffer_size) if format is None: tokens = str(uri).split(".")