Skip to content

Commit

Permalink
[TTS] Add unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan <[email protected]>
  • Loading branch information
rlangman committed May 9, 2023
1 parent 6bca3e6 commit a636f13
Showing 1 changed file with 57 additions and 1 deletion.
58 changes: 57 additions & 1 deletion tests/collections/tts/parts/utils/test_tts_dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,14 @@
from pathlib import Path

import pytest
import torch

from nemo.collections.tts.parts.utils.tts_dataset_utils import get_abs_rel_paths, get_audio_filepaths
from nemo.collections.tts.parts.utils.tts_dataset_utils import (
filter_dataset_by_duration,
get_abs_rel_paths,
get_audio_filepaths,
stack_tensors,
)


class TestTTSDatasetUtils:
Expand Down Expand Up @@ -53,3 +59,53 @@ def test_get_audio_paths(self):

assert abs_path == Path("/home/audio/examples/example.wav")
assert rel_path == audio_rel_path

@pytest.mark.run_only_on('CPU')
@pytest.mark.unit
def test_stack_tensors(self):
tensors = [torch.ones([2]), torch.ones([4]), torch.ones([3])]
max_lens = [6]
expected_output = torch.tensor(
[[1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0]], dtype=torch.float32
)

stacked_tensor = stack_tensors(tensors=tensors, max_lens=max_lens)

torch.testing.assert_close(stacked_tensor, expected_output)

@pytest.mark.run_only_on('CPU')
@pytest.mark.unit
def test_stack_tensors_3d(self):
tensors = [torch.ones([2, 2]), torch.ones([1, 3])]
max_lens = [4, 2]
expected_output = torch.tensor(
[[[1, 1, 0, 0], [1, 1, 0, 0]], [[1, 1, 1, 0], [0, 0, 0, 0]]], dtype=torch.float32
)

stacked_tensor = stack_tensors(tensors=tensors, max_lens=max_lens)

torch.testing.assert_close(stacked_tensor, expected_output)

@pytest.mark.run_only_on('CPU')
@pytest.mark.unit
def test_filter_dataset_by_duration(self):
min_duration = 1.0
max_duration = 10.0
entries = [
{"duration": 0.5},
{"duration": 10.0},
{"duration": 20.0},
{"duration": 0.1},
{"duration": 100.0},
{"duration": 5.0},
]

filtered_entries, total_hours, filtered_hours = filter_dataset_by_duration(
entries=entries, min_duration=min_duration, max_duration=max_duration
)

assert len(filtered_entries) == 2
assert filtered_entries[0]["duration"] == 10.0
assert filtered_entries[1]["duration"] == 5.0
assert total_hours == (135.6 / 3600.0)
assert filtered_hours == (15.0 / 3600.0)

0 comments on commit a636f13

Please sign in to comment.