Skip to content

Commit

Permalink
Fix file-like object support in FFmpeg dispatcher
Browse files Browse the repository at this point in the history
In dispatcher mode, FFmpeg backend does not handle file-like object,
and C++ implementation raises an issue.

This commit fixes it by normalizing file-like object to string.
  • Loading branch information
mthrok committed Apr 5, 2023
1 parent 5053aa7 commit d4fee19
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 3 deletions.
20 changes: 20 additions & 0 deletions test/torchaudio_unittest/backend/dispatcher/ffmpeg/info_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import io
import itertools
import os
import pathlib
import tarfile
from contextlib import contextmanager
from functools import partial
Expand Down Expand Up @@ -35,6 +36,25 @@
class TestInfo(TempDirMixin, PytorchTestCase):
_info = partial(get_info_func(), backend="ffmpeg")

def test_pathlike(self):
"""FFmpeg dispatcher can query audio data from pathlike object"""
sample_rate = 16000
dtype = "float32"
num_channels = 2
duration = 1

path = self.get_temp_path("data.wav")
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)

info = self._info(pathlib.Path(path))
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)
assert info.encoding == get_encoding("wav", dtype)


@parameterized.expand(
list(
itertools.product(
Expand Down
16 changes: 16 additions & 0 deletions test/torchaudio_unittest/backend/dispatcher/ffmpeg/load_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io
import itertools
import pathlib
import tarfile
from functools import partial

Expand Down Expand Up @@ -125,6 +126,21 @@ def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
class TestLoad(LoadTestBase):
"""Test the correctness of `self._load` for various formats"""

def test_pathlike(self):
"""FFmpeg dispatcher can load waveform from pathlike object"""
sample_rate = 16000
dtype = "float32"
num_channels = 2
duration = 1

path = self.get_temp_path("data.wav")
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)

waveform, sr = self._load(pathlib.Path(path))
self.assertEqual(sr, sample_rate)
self.assertEqual(waveform, data)

@parameterized.expand(
list(
itertools.product(
Expand Down
12 changes: 12 additions & 0 deletions test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io
import os
import pathlib
import subprocess
import sys
from functools import partial
Expand Down Expand Up @@ -146,6 +147,17 @@ def assert_save_consistency(
@skipIfNoExec("ffmpeg")
@skipIfNoFFmpeg
class SaveTest(SaveTestBase):
def test_pathlike(self):
"""FFmpeg dispatcher can save audio data to pathlike object"""
sample_rate = 16000
dtype = "float32"
num_channels = 2
duration = 1

path = self.get_temp_path("data.wav")
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
self._save(pathlib.Path(path), data, sample_rate)

@nested_params(
["path", "fileobj", "bytesio"],
[
Expand Down
6 changes: 3 additions & 3 deletions torchaudio/_backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def info(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str], buffer_s
if hasattr(uri, "read"):
metadata = info_audio_fileobj(uri, format, buffer_size=buffer_size)
else:
metadata = info_audio(uri, format)
metadata = info_audio(os.path.normpath(uri), format)
metadata.bits_per_sample = _get_bits_per_sample(metadata.encoding, metadata.bits_per_sample)
metadata.encoding = _map_encoding(metadata.encoding)
return metadata
Expand All @@ -108,7 +108,7 @@ def load(
buffer_size,
)
else:
return load_audio(uri, frame_offset, num_frames, normalize, channels_first, format)
return load_audio(os.path.normpath(uri), frame_offset, num_frames, normalize, channels_first, format)

@staticmethod
def save(
Expand All @@ -122,7 +122,7 @@ def save(
buffer_size: int = 4096,
) -> None:
save_audio(
uri,
os.path.normpath(uri),
src,
sample_rate,
channels_first,
Expand Down

0 comments on commit d4fee19

Please sign in to comment.