Skip to content

Commit

Permalink
[Cherry-pick] Properly set #samples passed to encoder (#3204) (#3239)
Browse files Browse the repository at this point in the history
Summary:
Some audio encoders expect specific, exact number of samples described as in `AVCodecContext.frame_size`.

The `AVFrame.nb_samples` is set for the frames passed to `AVFilterGraph`,
but frames coming out of the graph do not necessarily have the same numbr of frames.

This causes issues with encoding OPUS (among others).

This commit fixes it by inserting `asetnsamples` to filter graph if a fixed number of samples is requested.

Note:
It turned out that FFmpeg 4.1 has issue with OPUS encoding. It does not properly discard some sample.
We should probably move the minimum required FFmpeg to 4.2, but I am not sure if we can enforce it via ABI.
Work around will be to issue an warning if encoding OPUS with 4.1. (follow-up)

Pull Request resolved: #3204

Reviewed By: nateanl

Differential Revision: D44374668

Pulled By: mthrok

fbshipit-source-id: 10ef5333dc0677dfb83c8e40b78edd8ded1b21dc
  • Loading branch information
mthrok authored Apr 5, 2023
1 parent 3b40834 commit 9df28ff
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 15 deletions.
73 changes: 58 additions & 15 deletions test/torchaudio_unittest/io/stream_writer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from parameterized import parameterized, parameterized_class
from torchaudio_unittest.common_utils import (
get_asset_path,
get_sinusoid,
is_ffmpeg_available,
nested_params,
rgb_to_yuv_ccir,
Expand Down Expand Up @@ -293,28 +294,58 @@ def test_video_num_frames(self, framerate, resolution, format):
pass

@nested_params(
["wav", "mp3", "flac"],
["wav", "flac"],
[8000, 16000, 44100],
[1, 2],
)
def test_audio_num_frames(self, ext, sample_rate, num_channels):
""""""
def test_audio_num_frames_lossless(self, ext, sample_rate, num_channels):
"""Lossless format preserves the data"""
filename = f"test.{ext}"

data = get_sinusoid(sample_rate=sample_rate, n_channels=num_channels, dtype="int16", channels_first=False)

# Write data
dst = self.get_dst(filename)
s = torchaudio.io.StreamWriter(dst=dst, format=ext)
s.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels)
s.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels, format="s16")
with s.open():
s.write_audio_chunk(0, data)

freq = 300
duration = 60
theta = torch.linspace(0, freq * 2 * 3.14 * duration, sample_rate * duration)
if num_channels == 1:
chunk = torch.sin(theta).unsqueeze(-1)
else:
chunk = torch.stack([torch.sin(theta), torch.cos(theta)], dim=-1)
if self.test_fileobj:
dst.flush()

# Load data
s = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
s.add_audio_stream(-1)
s.process_all_packets()
(saved,) = s.pop_chunks()

self.assertEqual(saved, data)

@parameterized.expand(
[
("mp3", 1, 8000),
("mp3", 1, 16000),
("mp3", 1, 44100),
("mp3", 2, 8000),
("mp3", 2, 16000),
("mp3", 2, 44100),
("opus", 1, 48000),
]
)
def test_audio_num_frames_lossy(self, ext, num_channels, sample_rate):
"""Saving audio preserves the number of channels and frames"""
filename = f"test.{ext}"

data = get_sinusoid(sample_rate=sample_rate, n_channels=num_channels, channels_first=False)

# Write data
dst = self.get_dst(filename)
s = torchaudio.io.StreamWriter(dst=dst, format=ext)
s.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels)
with s.open():
s.write_audio_chunk(0, chunk)
s.write_audio_chunk(0, data)

if self.test_fileobj:
dst.flush()

Expand All @@ -324,9 +355,21 @@ def test_audio_num_frames(self, ext, sample_rate, num_channels):
s.process_all_packets()
(saved,) = s.pop_chunks()

assert saved.shape == chunk.shape
if format in ["wav", "flac"]:
self.assertEqual(saved, chunk)
# This test fails for OPUS if FFmpeg is 4.1, but it passes for 4.2+
# 4.1 produces 48312 samples (extra 312)
# Probably this commit fixes it.
# https://github.com/FFmpeg/FFmpeg/commit/18aea7bdd96b320a40573bccabea56afeccdd91c
# TODO: issue warning if 4.1?
if ext == "opus":
ver = torchaudio.utils.ffmpeg_utils.get_versions()["libavcodec"]
# 5.1 libavcodec 59. 18.100
# 4.4 libavcodec 58.134.100
# 4.3 libavcodec 58. 91.100
# 4.2 libavcodec 58. 54.100
# 4.1 libavcodec 58. 35.100
if ver[0] < 59 and ver[1] < 54:
return
self.assertEqual(saved.shape, data.shape)

def test_preserve_fps(self):
"""Decimal point frame rate is properly saved
Expand Down
3 changes: 3 additions & 0 deletions torchaudio/csrc/ffmpeg/stream_writer/stream_writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,9 @@ std::unique_ptr<FilterGraph> _get_audio_filter(
AVCodecContextPtr& ctx) {
std::stringstream desc;
desc << "aformat=" << av_get_sample_fmt_name(ctx->sample_fmt);
if (ctx->frame_size) {
desc << ",asetnsamples=n=" << ctx->frame_size << ":p=0";
}

auto p = std::make_unique<FilterGraph>(AVMEDIA_TYPE_AUDIO);
p->add_audio_src(fmt, ctx->time_base, ctx->sample_rate, ctx->channel_layout);
Expand Down

0 comments on commit 9df28ff

Please sign in to comment.