Skip to content

Commit

Permalink
[Cherry-pick] Fix path-like object support in FFmpeg dispatcher (#3243,
Browse files Browse the repository at this point in the history
#3248) (#3245)

* Fix path-like object support in FFmpeg dispatcher (#3243)

Summary:
In dispatcher mode, FFmpeg backend does not handle file-like object, and C++ implementation raises an issue.

This commit fixes it by normalizing file-like object to string.

Pull Request resolved: #3243

Reviewed By: nateanl

Differential Revision: D44719280

Pulled By: mthrok

fbshipit-source-id: 9dae459e2a5fb4992b4ef53fe4829fe8c35b2edd

* Fix path normalization for StreamWriter-based save operation (#3248)

Summary:
Follow up of #3243. Save compat module had different semantics than info and load, which requires different way of performing path normalization.

Pull Request resolved: #3248

Reviewed By: hwangjeff

Differential Revision: D44774997

Pulled By: mthrok

fbshipit-source-id: 4b967ae3ca6b45850d455b8e95aaa31762c5457e
  • Loading branch information
mthrok authored Apr 7, 2023
1 parent a4ea69e commit d92216d
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 4 deletions.
19 changes: 19 additions & 0 deletions test/torchaudio_unittest/backend/dispatcher/ffmpeg/info_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import io
import itertools
import os
import pathlib
import tarfile
from contextlib import contextmanager
from functools import partial
Expand Down Expand Up @@ -35,6 +36,24 @@
class TestInfo(TempDirMixin, PytorchTestCase):
_info = partial(get_info_func(), backend="ffmpeg")

def test_pathlike(self):
"""FFmpeg dispatcher can query audio data from pathlike object"""
sample_rate = 16000
dtype = "float32"
num_channels = 2
duration = 1

path = self.get_temp_path("data.wav")
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)

info = self._info(pathlib.Path(path))
assert info.sample_rate == sample_rate
assert info.num_frames == sample_rate * duration
assert info.num_channels == num_channels
assert info.bits_per_sample == sox_utils.get_bit_depth(dtype)
assert info.encoding == get_encoding("wav", dtype)

@parameterized.expand(
list(
itertools.product(
Expand Down
16 changes: 16 additions & 0 deletions test/torchaudio_unittest/backend/dispatcher/ffmpeg/load_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io
import itertools
import pathlib
import tarfile
from functools import partial

Expand Down Expand Up @@ -125,6 +126,21 @@ def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
class TestLoad(LoadTestBase):
"""Test the correctness of `self._load` for various formats"""

def test_pathlike(self):
"""FFmpeg dispatcher can load waveform from pathlike object"""
sample_rate = 16000
dtype = "float32"
num_channels = 2
duration = 1

path = self.get_temp_path("data.wav")
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
save_wav(path, data, sample_rate)

waveform, sr = self._load(pathlib.Path(path))
self.assertEqual(sr, sample_rate)
self.assertEqual(waveform, data)

@parameterized.expand(
list(
itertools.product(
Expand Down
12 changes: 12 additions & 0 deletions test/torchaudio_unittest/backend/dispatcher/ffmpeg/save_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io
import os
import pathlib
import subprocess
import sys
from functools import partial
Expand Down Expand Up @@ -146,6 +147,17 @@ def assert_save_consistency(
@skipIfNoExec("ffmpeg")
@skipIfNoFFmpeg
class SaveTest(SaveTestBase):
def test_pathlike(self):
"""FFmpeg dispatcher can save audio data to pathlike object"""
sample_rate = 16000
dtype = "float32"
num_channels = 2
duration = 1

path = self.get_temp_path("data.wav")
data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
self._save(pathlib.Path(path), data, sample_rate)

@nested_params(
["path", "fileobj", "bytesio"],
[
Expand Down
4 changes: 2 additions & 2 deletions torchaudio/_backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def info(uri: Union[BinaryIO, str, os.PathLike], format: Optional[str], buffer_s
if hasattr(uri, "read"):
metadata = info_audio_fileobj(uri, format, buffer_size=buffer_size)
else:
metadata = info_audio(uri, format)
metadata = info_audio(os.path.normpath(uri), format)
metadata.bits_per_sample = _get_bits_per_sample(metadata.encoding, metadata.bits_per_sample)
metadata.encoding = _map_encoding(metadata.encoding)
return metadata
Expand All @@ -108,7 +108,7 @@ def load(
buffer_size,
)
else:
return load_audio(uri, frame_offset, num_frames, normalize, channels_first, format)
return load_audio(os.path.normpath(uri), frame_offset, num_frames, normalize, channels_first, format)

@staticmethod
def save(
Expand Down
7 changes: 5 additions & 2 deletions torchaudio/io/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,11 @@ def save_audio(
bits_per_sample: Optional[int] = None,
buffer_size: int = 4096,
) -> None:
if hasattr(uri, "write") and format is None:
raise RuntimeError("'format' is required when saving to file object.")
if hasattr(uri, "write"):
if format is None:
raise RuntimeError("'format' is required when saving to file object.")
else:
uri = os.path.normpath(uri)
s = StreamWriter(uri, format=format, buffer_size=buffer_size)
if format is None:
tokens = str(uri).split(".")
Expand Down

0 comments on commit d92216d

Please sign in to comment.