diff --git a/src/torchaudio/functional/filtering.py b/src/torchaudio/functional/filtering.py index dc65a8cfa8..826e075fda 100644 --- a/src/torchaudio/functional/filtering.py +++ b/src/torchaudio/functional/filtering.py @@ -1660,6 +1660,9 @@ def vad( flushedLen_ns = (measures_len - num_measures_to_flush) * measure_period_ns break # end for window + if not has_triggered: + return waveform[..., :0].view(shape[:-1] + torch.Size([0])) + res = waveform[:, pos - samplesLen_ns + flushedLen_ns :] # unpack batch return res.view(shape[:-1] + res.shape[-1:]) diff --git a/test/torchaudio_unittest/transforms/transforms_test_impl.py b/test/torchaudio_unittest/transforms/transforms_test_impl.py index e6726a3183..2e70ab4ad3 100644 --- a/test/torchaudio_unittest/transforms/transforms_test_impl.py +++ b/test/torchaudio_unittest/transforms/transforms_test_impl.py @@ -478,3 +478,18 @@ def test_specaugment(self, n_time_masks, time_mask_param, n_freq_masks, freq_mas self.assertTrue(diff > 0) else: self.assertTrue(diff == 0) + + @parameterized.expand( + [ + ((32000,), (0,), 16000), + ((1, 32000), (1, 0), 32000), + ((2, 44100), (2, 0), 32000), + ((2, 2, 44100), (2, 2, 0), 32000), + ] + ) + def test_vad_on_zero_audio(self, input_shape, output_shape, sample_rate: int): + """VAD should return zero when input is zero Tensor""" + inpt = torch.zeros(input_shape, dtype=self.dtype, device=self.device) + expected_output = torch.zeros(output_shape, dtype=self.dtype, device=self.device) + result = T.Vad(sample_rate)(inpt) + self.assertEqual(result, expected_output)