diff --git a/test/torchaudio_unittest/io/stream_writer_test.py b/test/torchaudio_unittest/io/stream_writer_test.py index 4f20e6ec36..59938d4dd4 100644 --- a/test/torchaudio_unittest/io/stream_writer_test.py +++ b/test/torchaudio_unittest/io/stream_writer_test.py @@ -4,6 +4,7 @@ from parameterized import parameterized, parameterized_class from torchaudio_unittest.common_utils import ( get_asset_path, + get_sinusoid, is_ffmpeg_available, nested_params, rgb_to_yuv_ccir, @@ -293,28 +294,58 @@ def test_video_num_frames(self, framerate, resolution, format): pass @nested_params( - ["wav", "mp3", "flac"], + ["wav", "flac"], [8000, 16000, 44100], [1, 2], ) - def test_audio_num_frames(self, ext, sample_rate, num_channels): - """""" + def test_audio_num_frames_lossless(self, ext, sample_rate, num_channels): + """Lossless format preserves the data""" filename = f"test.{ext}" + data = get_sinusoid(sample_rate=sample_rate, n_channels=num_channels, dtype="int16", channels_first=False) + # Write data dst = self.get_dst(filename) s = torchaudio.io.StreamWriter(dst=dst, format=ext) - s.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels) + s.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels, format="s16") + with s.open(): + s.write_audio_chunk(0, data) - freq = 300 - duration = 60 - theta = torch.linspace(0, freq * 2 * 3.14 * duration, sample_rate * duration) - if num_channels == 1: - chunk = torch.sin(theta).unsqueeze(-1) - else: - chunk = torch.stack([torch.sin(theta), torch.cos(theta)], dim=-1) + if self.test_fileobj: + dst.flush() + + # Load data + s = torchaudio.io.StreamReader(src=self.get_temp_path(filename)) + s.add_audio_stream(-1) + s.process_all_packets() + (saved,) = s.pop_chunks() + + self.assertEqual(saved, data) + + @parameterized.expand( + [ + ("mp3", 1, 8000), + ("mp3", 1, 16000), + ("mp3", 1, 44100), + ("mp3", 2, 8000), + ("mp3", 2, 16000), + ("mp3", 2, 44100), + ("opus", 1, 48000), + ] + ) + def test_audio_num_frames_lossy(self, ext, num_channels, sample_rate): + """Saving audio preserves the number of channels and frames""" + filename = f"test.{ext}" + + data = get_sinusoid(sample_rate=sample_rate, n_channels=num_channels, channels_first=False) + + # Write data + dst = self.get_dst(filename) + s = torchaudio.io.StreamWriter(dst=dst, format=ext) + s.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels) with s.open(): - s.write_audio_chunk(0, chunk) + s.write_audio_chunk(0, data) + if self.test_fileobj: dst.flush() @@ -324,9 +355,21 @@ def test_audio_num_frames(self, ext, sample_rate, num_channels): s.process_all_packets() (saved,) = s.pop_chunks() - assert saved.shape == chunk.shape - if format in ["wav", "flac"]: - self.assertEqual(saved, chunk) + # This test fails for OPUS if FFmpeg is 4.1, but it passes for 4.2+ + # 4.1 produces 48312 samples (extra 312) + # Probably this commit fixes it. + # https://github.com/FFmpeg/FFmpeg/commit/18aea7bdd96b320a40573bccabea56afeccdd91c + # TODO: issue warning if 4.1? + if ext == "opus": + ver = torchaudio.utils.ffmpeg_utils.get_versions()["libavcodec"] + # 5.1 libavcodec 59. 18.100 + # 4.4 libavcodec 58.134.100 + # 4.3 libavcodec 58. 91.100 + # 4.2 libavcodec 58. 54.100 + # 4.1 libavcodec 58. 35.100 + if ver[0] < 59 and ver[1] < 54: + return + self.assertEqual(saved.shape, data.shape) def test_preserve_fps(self): """Decimal point frame rate is properly saved diff --git a/torchaudio/csrc/ffmpeg/stream_writer/stream_writer.cpp b/torchaudio/csrc/ffmpeg/stream_writer/stream_writer.cpp index d91904a56f..146f873a8e 100644 --- a/torchaudio/csrc/ffmpeg/stream_writer/stream_writer.cpp +++ b/torchaudio/csrc/ffmpeg/stream_writer/stream_writer.cpp @@ -409,6 +409,9 @@ std::unique_ptr _get_audio_filter( AVCodecContextPtr& ctx) { std::stringstream desc; desc << "aformat=" << av_get_sample_fmt_name(ctx->sample_fmt); + if (ctx->frame_size) { + desc << ",asetnsamples=n=" << ctx->frame_size << ":p=0"; + } auto p = std::make_unique(AVMEDIA_TYPE_AUDIO); p->add_audio_src(fmt, ctx->time_base, ctx->sample_rate, ctx->channel_layout);