Skip to content

Commit

Permalink
Fixing bug in unsort_tensor (#6320)
Browse files Browse the repository at this point in the history
* Fixing bug in unsort_tensor

Signed-off-by: Boris Fomitchev <[email protected]>

* docstrings added

Signed-off-by: Boris Fomitchev <[email protected]>

---------

Signed-off-by: Boris Fomitchev <[email protected]>
  • Loading branch information
borisfom authored and Slyne Deng committed Apr 3, 2023
1 parent e2224fb commit 6c0c064
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 5 deletions.
23 changes: 21 additions & 2 deletions nemo/collections/tts/parts/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,33 @@ def get_mask_from_lengths(lengths: Optional[torch.Tensor] = None, x: Optional[to
def sort_tensor(
context: torch.Tensor, lens: torch.Tensor, dim: Optional[int] = 0, descending: Optional[bool] = True
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Sorts elements in context by the dim lengths specified in lens
Args:
context: source tensor, sorted by lens
lens: lengths of elements of context along the dimension dim
dim: Optional[int] : dimension to sort by
Returns:
context: tensor sorted by lens along dimension dim
lens_sorted: lens tensor, sorted
ids_sorted: reorder ids to be used to restore original order
"""
lens_sorted, ids_sorted = torch.sort(lens, descending=descending)
context = torch.index_select(context, dim, ids_sorted)
return context, lens_sorted, ids_sorted


def unsort_tensor(ordered: torch.Tensor, indices: torch.Tensor, dim: Optional[int] = 0) -> torch.Tensor:
unsort_ids = indices.gather(0, indices.argsort(0, descending=True))
return torch.index_select(ordered, dim, unsort_ids)
"""Reverses the result of sort_tensor function:
o, _, ids = sort_tensor(x,l)
assert unsort_tensor(o,ids) == x
Args:
ordered: context tensor, sorted by lengths
indices: torch.tensor: 1D tensor with 're-order' indices returned by sort_tensor
Returns:
ordered tensor in original order (before calling sort_tensor)
"""
return torch.index_select(ordered, dim, indices.argsort(0))


@jit(nopython=True)
Expand Down
16 changes: 13 additions & 3 deletions tests/collections/tts/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@
import pytest
import torch

from nemo.collections.tts.parts.utils.helpers import regulate_len
from nemo.collections.tts.parts.utils.helpers import regulate_len, sort_tensor, unsort_tensor


def sample_duration_input(max_length=64, group_size=2):
def sample_duration_input(max_length=64, group_size=2, batch_size=3):
generator = torch.Generator()
generator.manual_seed(0)
batch_size = 3
lengths = torch.randint(max_length // 4, max_length - 7, (batch_size,), generator=generator)
durs = torch.ones(batch_size, max_length) * group_size
durs[0, lengths[0]] += 1
Expand All @@ -30,6 +29,17 @@ def sample_duration_input(max_length=64, group_size=2):
return durs, enc, lengths


@pytest.mark.unit
def test_sort_unsort():
durs_in, enc_in, dur_lens = sample_duration_input(batch_size=13)
print("In: ", enc_in)
sorted_enc, sorted_len, sorted_ids = sort_tensor(enc_in, dur_lens)
unsorted_enc = unsort_tensor(sorted_enc, sorted_ids)
print("Out: ", unsorted_enc)

assert torch.all(unsorted_enc == enc_in)


@pytest.mark.unit
def test_regulate_len():
group_size = 2
Expand Down

0 comments on commit 6c0c064

Please sign in to comment.