Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Cherry-picked 2.0.1] Fix path-like object support in FFmpeg dispatcher #3243

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions test/torchaudio_unittest/backend/dispatcher/ffmpeg/info_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import io
import itertools
import os
import pathlib
import tarfile
from contextlib import contextmanager
from functools import partial
Expand Down Expand Up @@ -35,6 +36,24 @@
class TestInfo(TempDirMixin, PytorchTestCase):
_info = partial(get_info_func(), backend="ffmpeg")

def test_pathlike(self):
    """FFmpeg dispatcher can query audio metadata when given a ``pathlib.Path``."""
    # Parameters of the reference clip written to disk below.
    dtype = "float32"
    sample_rate, num_channels, duration = 16000, 2, 1

    # Create the fixture through the plain string path ...
    wav_path = self.get_temp_path("data.wav")
    reference = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
    save_wav(wav_path, reference, sample_rate)

    # ... then query it through a path-like object rather than a str.
    metadata = self._info(pathlib.Path(wav_path))

    # Every reported field must match what was written.
    assert metadata.sample_rate == sample_rate
    assert metadata.num_frames == sample_rate * duration
    assert metadata.num_channels == num_channels
    assert metadata.bits_per_sample == sox_utils.get_bit_depth(dtype)
    assert metadata.encoding == get_encoding("wav", dtype)

@parameterized.expand(
list(
itertools.product(
Expand Down
16 changes: 16 additions & 0 deletions test/torchaudio_unittest/backend/dispatcher/ffmpeg/load_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io
import itertools
import pathlib
import tarfile
from functools import partial

Expand Down Expand Up @@ -125,6 +126,21 @@ def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
class TestLoad(LoadTestBase):
"""Test the correctness of `self._load` for various formats"""

def test_pathlike(self):
    """FFmpeg dispatcher can load a waveform when given a ``pathlib.Path``."""
    # Parameters of the reference clip written to disk below.
    dtype = "float32"
    sample_rate, num_channels, duration = 16000, 2, 1

    # Create the fixture through the plain string path ...
    wav_path = self.get_temp_path("data.wav")
    expected = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
    save_wav(wav_path, expected, sample_rate)

    # ... then load it back through a path-like object rather than a str.
    loaded, loaded_rate = self._load(pathlib.Path(wav_path))

    # Both the sample rate and the samples must round-trip.
    self.assertEqual(loaded_rate, sample_rate)
    self.assertEqual(loaded, expected)

@parameterized.expand(
list(
itertools.product(
Expand Down
12 changes: 12 additions & 0 deletions test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io
import os
import pathlib
import subprocess
import sys
from functools import partial
Expand Down Expand Up @@ -146,6 +147,17 @@ def assert_save_consistency(
@skipIfNoExec("ffmpeg")
@skipIfNoFFmpeg
class SaveTest(SaveTestBase):
def test_pathlike(self):
    """FFmpeg dispatcher can save audio data when the destination is a ``pathlib.Path``."""
    # Parameters of the clip to be written.
    dtype = "float32"
    sample_rate, num_channels, duration = 16000, 2, 1

    destination = pathlib.Path(self.get_temp_path("data.wav"))
    waveform = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)

    # Smoke test: passes as long as saving to a path-like target does not raise.
    self._save(destination, waveform, sample_rate)

@nested_params(
["path", "fileobj", "bytesio"],
[
Expand Down
6 changes: 3 additions & 3 deletions torchaudio/_backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def info(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str], buffer_s
if hasattr(uri, "read"):
metadata = info_audio_fileobj(uri, format, buffer_size=buffer_size)
else:
metadata = info_audio(uri, format)
metadata = info_audio(os.path.normpath(uri), format)
metadata.bits_per_sample = _get_bits_per_sample(metadata.encoding, metadata.bits_per_sample)
metadata.encoding = _map_encoding(metadata.encoding)
return metadata
Expand All @@ -108,7 +108,7 @@ def load(
buffer_size,
)
else:
return load_audio(uri, frame_offset, num_frames, normalize, channels_first, format)
return load_audio(os.path.normpath(uri), frame_offset, num_frames, normalize, channels_first, format)

@staticmethod
def save(
Expand All @@ -122,7 +122,7 @@ def save(
buffer_size: int = 4096,
) -> None:
save_audio(
uri,
os.path.normpath(uri),
src,
sample_rate,
channels_first,
Expand Down