Skip to content

Commit

Permalink
fix: F.vad returns an incorrect result when the input is all-zero audio
Browse files Browse the repository at this point in the history
  • Loading branch information
wasd96040501 committed Nov 4, 2023
1 parent 36f5010 commit c9983a8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/torchaudio/functional/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1660,6 +1660,9 @@ def vad(
flushedLen_ns = (measures_len - num_measures_to_flush) * measure_period_ns
break
# end for window
if not has_triggered:
return waveform[..., :0].view(shape[:-1] + torch.Size([0]))

res = waveform[:, pos - samplesLen_ns + flushedLen_ns :]
# unpack batch
return res.view(shape[:-1] + res.shape[-1:])
12 changes: 12 additions & 0 deletions test/torchaudio_unittest/transforms/sox_compatibility_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,18 @@ def test_vad(self, filename):
result = T.Vad(sample_rate)(data)
self.assert_sox_effect(result, path, ["vad"])

@parameterized.expand(
    [
        (torch.zeros(32000), torch.zeros(0), 16000),
        (torch.zeros(1, 32000), torch.zeros(1, 0), 32000),
        (torch.zeros(2, 44100), torch.zeros(2, 0), 32000),
        (torch.zeros(2, 2, 44100), torch.zeros(2, 2, 0), 32000),
    ]
)
def test_vad_on_zero_audio(
    self,
    inpt: torch.Tensor,
    expected_output: torch.Tensor,
    sample_rate: int,
):
    """Vad on an all-zero waveform yields a tensor with an empty last dimension.

    Each parameter set supplies an all-zero input (1-D, 2-D and 3-D shapes),
    the expected result, and the sample rate passed to ``T.Vad``. The expected
    result keeps every leading dimension of the input and has length 0 along
    the final (time) axis.
    """
    detector = T.Vad(sample_rate)
    trimmed = detector(inpt)
    self.assertEqual(trimmed, expected_output)

def test_vad_warning(self):
"""vad should throw a warning if input dimension is greater than 2"""
sample_rate = 41100
Expand Down

0 comments on commit c9983a8

Please sign in to comment.