Skip to content

Commit

Permalink
Ragged batching changes for RadTTS, some refactoring (NVIDIA#6020)
Browse files Browse the repository at this point in the history
Signed-off-by: Boris Fomitchev <[email protected]>
  • Loading branch information
borisfom authored and titu1994 committed Mar 24, 2023
1 parent 26ddaa3 commit 6bd21ce
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 134 deletions.
85 changes: 84 additions & 1 deletion nemo/collections/tts/helpers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def sort_tensor(


def unsort_tensor(ordered: torch.Tensor, indices: torch.Tensor, dim: Optional[int] = 0) -> torch.Tensor:
    """
    Undo a reordering that was applied with ``torch.index_select``.

    Args:
        ordered: tensor whose entries along ``dim`` were reordered.
        indices: 1-D permutation used for the reordering, i.e. it is assumed that
            ``ordered == torch.index_select(original, dim, indices)``.
            NOTE(review): the producer of ``indices`` is not visible here - confirm
            it is the index tensor returned by the matching sort helper.
        dim: dimension along which the reordering was applied.

    Returns:
        ``ordered`` with its entries along ``dim`` put back in their original order.
    """
    # The argsort of a permutation is its inverse permutation.  The previous
    # expression, indices.gather(0, indices.argsort(0, descending=True)), always
    # evaluates to [n-1, ..., 0] and therefore only reversed the tensor, which
    # is the correct inverse solely when the sort itself was a reversal.
    unsort_ids = indices.argsort(0)
    return torch.index_select(ordered, dim, unsort_ids)


Expand Down Expand Up @@ -706,3 +706,86 @@ def mask_sequence_tensor(tensor: torch.Tensor, lengths: torch.Tensor):
raise ValueError("Can only mask tensors of shape B x D x L and B x D1 x D2 x L")

return tensor * mask


@torch.jit.script
def batch_from_ragged(
    text: torch.Tensor,
    pitch: torch.Tensor,
    pace: torch.Tensor,
    batch_lengths: torch.Tensor,
    padding_idx: int = -1,
    volume: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Unpack flat (ragged) 1-D inputs into padded 2-D batches.

    ``batch_lengths`` holds cumulative offsets: sequence ``i`` occupies
    ``[batch_lengths[i], batch_lengths[i + 1])`` in each flat input, so it has
    one more entry than there are sequences.

    Args:
        text: flat token ids.
        pitch: flat per-token pitch values.
        pace: flat per-token pace values.
        batch_lengths: cumulative sequence offsets (cast to int64 internally).
        padding_idx: value used to pad ``texts``.
        volume: optional flat per-token volume values.

    Returns:
        ``(texts, pitches, paces, volumes, lens)``. The first four have shape
        ``(num_sequences, max_len)``; ``lens`` holds each sequence length.
        ``volumes`` is all ones when ``volume`` is None.
    """
    offsets = batch_lengths.to(dtype=torch.int64)
    num_seqs = offsets.shape[0] - 1
    max_len = torch.max(offsets[1:] - offsets[:-1])
    dev = text.device

    # Padded outputs, pre-filled with their padding values.
    texts = torch.zeros(num_seqs, max_len, dtype=torch.int64, device=dev) + padding_idx
    # NOTE(review): pitch is padded with 1.0 (not 0.0) - confirm this is intended.
    pitches = torch.ones(num_seqs, max_len, dtype=torch.float32, device=dev)
    paces = torch.ones(num_seqs, max_len, dtype=torch.float32, device=dev)
    volumes = torch.ones(num_seqs, max_len, dtype=torch.float32, device=dev)
    lens = torch.zeros(num_seqs, dtype=torch.int64, device=dev)

    for row in range(num_seqs):
        seq_start = offsets[row]
        seq_end = offsets[row + 1]
        seq_len = seq_end - seq_start
        lens[row] = seq_len
        texts[row, :seq_len] = text[seq_start:seq_end]
        pitches[row, :seq_len] = pitch[seq_start:seq_end]
        paces[row, :seq_len] = pace[seq_start:seq_end]
        if volume is not None:
            volumes[row, :seq_len] = volume[seq_start:seq_end]

    return texts, pitches, paces, volumes, lens


def sample_tts_input(
    export_config, device, max_batch=1, max_dim=127,
):
    """
    Generates input examples for tracing etc.

    Args:
        export_config: mapping of export options. ``"emb_range"`` (the ``(low, high)``
            bounds passed to ``torch.randint`` for token ids) is required.
            ``"enable_ragged_batches"`` and ``"enable_volume"`` default to False when
            absent. When ``"num_speakers"`` is present a ``'speaker'`` entry is added.
        device: device on which all tensors are created.
        max_batch: number of sequences in the example.
        max_dim: sequence length (padded mode) or average sequence length (ragged mode).

    Returns:
        A dict with ``'text'``, ``'pitch'``, ``'pace'`` and ``'batch_lengths'``, plus
        ``'volume'`` and ``'speaker'`` when enabled. In ragged mode the first three are
        flat tensors of ``max_batch * max_dim`` elements and ``'batch_lengths'`` holds
        ``max_batch + 1`` cumulative int32 offsets; otherwise they have shape
        ``(max_batch, max_dim)`` and ``'batch_lengths'`` holds one length per sequence.
    """
    # Missing flags are treated as disabled so callers with a reduced config still work.
    ragged = bool(export_config.get("enable_ragged_batches", False))
    sz = (max_batch * max_dim,) if ragged else (max_batch, max_dim)
    inp = torch.randint(*export_config["emb_range"], sz, device=device, dtype=torch.int64)
    pitch = torch.randn(sz, device=device, dtype=torch.float32) * 0.5
    pace = torch.clamp(torch.randn(sz, device=device, dtype=torch.float32) * 0.1 + 1, min=0.01)
    inputs = {'text': inp, 'pitch': pitch, 'pace': pace}
    if ragged:
        total_len = sz[0]
        left_over_size = total_len
        offsets = [0]
        for i in range(1, max_batch):
            remaining = max_batch - i
            equal_len = (left_over_size - remaining) // remaining
            if equal_len >= 2:
                length = int(
                    torch.randint(equal_len // 2, equal_len, (1,), device=device, dtype=torch.int32).item()
                )
            else:
                # Too few tokens left to sample from a range (torch.randint needs
                # low < high): hand out one token while every remaining sequence
                # can still receive at least one.
                length = 1 if left_over_size > remaining else 0
            offsets.append(offsets[-1] + length)
            left_over_size -= length
        if max_batch > 0:
            # The last sequence takes whatever is left.
            offsets.append(offsets[-1] + left_over_size)
        batch_lengths = torch.tensor(offsets, device=device, dtype=torch.int32)

        # Internal sanity check: the offsets must cover the whole flat input.
        covered = offsets[-1] - offsets[0]
        assert covered == total_len, f"sum: {covered}, sz: {total_len}, lengths:{batch_lengths}"
    else:
        batch_lengths = torch.randint(max_dim // 2, max_dim, (max_batch,), device=device, dtype=torch.int32)
        # Make sure at least one sequence uses the full length.
        batch_lengths[0] = max_dim
    inputs['batch_lengths'] = batch_lengths

    if export_config.get("enable_volume", False):
        volume = torch.clamp(torch.randn(sz, device=device, dtype=torch.float32) * 0.1 + 1, min=0.01)
        inputs['volume'] = volume

    if "num_speakers" in export_config:
        inputs['speaker'] = torch.randint(
            0, export_config["num_speakers"], (max_batch,), device=device, dtype=torch.int64
        )
    return inputs
90 changes: 18 additions & 72 deletions nemo/collections/tts/models/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@
from pytorch_lightning.loggers import TensorBoardLogger

from nemo.collections.common.parts.preprocessing import parsers
from nemo.collections.tts.helpers.helpers import plot_alignment_to_numpy, plot_spectrogram_to_numpy, process_batch
from nemo.collections.tts.helpers.helpers import (
batch_from_ragged,
plot_alignment_to_numpy,
plot_spectrogram_to_numpy,
process_batch,
sample_tts_input,
)
from nemo.collections.tts.losses.aligner_loss import BinLoss, ForwardSumLoss
from nemo.collections.tts.losses.fastpitchloss import DurationLoss, EnergyLoss, MelLoss, PitchLoss
from nemo.collections.tts.models.base import SpectrogramGenerator
Expand Down Expand Up @@ -156,7 +162,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
speaker_emb_condition_aligner,
)
self._input_types = self._output_types = None
self.export_config = {"enable_volume": False, "enable_ragged_batches": False}
self.export_config = {
"emb_range": (0, self.fastpitch.encoder.word_emb.num_embeddings),
"enable_volume": False,
"enable_ragged_batches": False,
}
if self.fastpitch.speaker_emb is not None:
self.export_config["num_speakers"] = cfg.n_speakers

def _get_default_text_tokenizer_conf(self):
text_tokenizer: TextTokenizerConfig = TextTokenizerConfig()
Expand Down Expand Up @@ -665,46 +677,14 @@ def input_example(self, max_batch=1, max_dim=44):
A tuple of input examples.
"""
par = next(self.fastpitch.parameters())
sz = (max_batch * max_dim,) if self.export_config["enable_ragged_batches"] else (max_batch, max_dim)
inp = torch.randint(
0, self.fastpitch.encoder.word_emb.num_embeddings, sz, device=par.device, dtype=torch.int64
)
pitch = torch.randn(sz, device=par.device, dtype=torch.float32) * 0.5
pace = torch.clamp(torch.randn(sz, device=par.device, dtype=torch.float32) * 0.1 + 1, min=0.01)

inputs = {'text': inp, 'pitch': pitch, 'pace': pace}

if self.export_config["enable_volume"]:
volume = torch.clamp(torch.randn(sz, device=par.device, dtype=torch.float32) * 0.1 + 1, min=0.01)
inputs['volume'] = volume
if self.export_config["enable_ragged_batches"]:
batch_lengths = torch.zeros((max_batch + 1), device=par.device, dtype=torch.int32)
left_over_size = sz[0]
batch_lengths[0] = 0
for i in range(1, max_batch):
length = torch.randint(1, left_over_size - (max_batch - i), (1,), device=par.device)
batch_lengths[i] = length + batch_lengths[i - 1]
left_over_size -= length.detach().cpu().numpy()[0]
batch_lengths[-1] = left_over_size + batch_lengths[-2]

sum = 0
index = 1
while index < len(batch_lengths):
sum += batch_lengths[index] - batch_lengths[index - 1]
index += 1
assert sum == sz[0], f"sum: {sum}, sz: {sz[0]}, lengths:{batch_lengths}"
inputs['batch_lengths'] = batch_lengths

if self.fastpitch.speaker_emb is not None:
inputs['speaker'] = torch.randint(
0, self.fastpitch.speaker_emb.num_embeddings, (max_batch,), device=par.device, dtype=torch.int64
)

inputs = sample_tts_input(self.export_config, par.device, max_batch=max_batch, max_dim=max_dim)
if 'enable_ragged_batches' not in self.export_config:
inputs.pop('batch_lengths', None)
return (inputs,)

def forward_for_export(self, text, pitch, pace, volume=None, batch_lengths=None, speaker=None):
if self.export_config["enable_ragged_batches"]:
text, pitch, pace, volume_tensor = create_batch(
text, pitch, pace, volume_tensor, lens = batch_from_ragged(
text, pitch, pace, batch_lengths, padding_idx=self.fastpitch.encoder.padding_idx, volume=volume
)
if volume is not None:
Expand Down Expand Up @@ -743,37 +723,3 @@ def interpolate_speaker(
)
new_speaker_emb = weight_speaker_1 * speaker_emb_1 + weight_speaker_2 * speaker_emb_2
self.fastpitch.speaker_emb.weight.data[new_speaker_id] = new_speaker_emb


@torch.jit.script
def create_batch(
    text: torch.Tensor,
    pitch: torch.Tensor,
    pace: torch.Tensor,
    batch_lengths: torch.Tensor,
    padding_idx: int = -1,
    volume: Optional[torch.Tensor] = None,
):
    """
    Unpack flat (ragged) 1-D inputs into padded 2-D batches.

    ``batch_lengths`` holds cumulative offsets: sequence ``i`` occupies
    ``[batch_lengths[i], batch_lengths[i + 1])`` in each flat input.

    Returns:
        ``(texts, pitches, paces, volumes)``, each of shape
        ``(num_sequences, max_len)``. Padding values are ``padding_idx`` for
        texts, 0.0 for pitches and 1.0 for paces and volumes. ``volumes`` is
        all ones when ``volume`` is None.
    """
    offsets = batch_lengths.to(torch.int64)
    num_seqs = offsets.shape[0] - 1
    max_len = torch.max(offsets[1:] - offsets[:-1])
    dev = text.device

    # Padded outputs, pre-filled with their padding values.
    texts = torch.zeros(num_seqs, max_len, dtype=torch.int64, device=dev) + padding_idx
    pitches = torch.zeros(num_seqs, max_len, dtype=torch.float32, device=dev)
    paces = torch.ones(num_seqs, max_len, dtype=torch.float32, device=dev)
    volumes = torch.ones(num_seqs, max_len, dtype=torch.float32, device=dev)

    for row in range(num_seqs):
        seq_start = offsets[row]
        seq_end = offsets[row + 1]
        seq_len = seq_end - seq_start

        texts[row, :seq_len] = text[seq_start:seq_end]
        pitches[row, :seq_len] = pitch[seq_start:seq_end]
        paces[row, :seq_len] = pace[seq_start:seq_end]
        if volume is not None:
            volumes[row, :seq_len] = volume[seq_start:seq_end]

    return texts, pitches, paces, volumes
114 changes: 56 additions & 58 deletions nemo/collections/tts/models/radtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@
from pytorch_lightning.loggers import TensorBoardLogger

from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import BaseTokenizer
from nemo.collections.tts.helpers.helpers import plot_alignment_to_numpy, regulate_len
from nemo.collections.tts.helpers.helpers import (
batch_from_ragged,
plot_alignment_to_numpy,
regulate_len,
sample_tts_input,
)
from nemo.collections.tts.losses.radttsloss import AttentionBinarizationLoss, RADTTSLoss
from nemo.collections.tts.models.base import SpectrogramGenerator
from nemo.core.classes import Exportable
Expand Down Expand Up @@ -81,6 +86,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self._tb_logger = None
self.cfg = cfg
self.log_train_images = False
self.export_config = {
"emb_range": (32, 64),
"enable_volume": True,
"enable_ragged_batches": False,
"num_speakers": self.model_config.n_speakers,
}
# print("initial self normalizer", self.normalizer)

def batch_dict(self, batch_data):
Expand Down Expand Up @@ -420,77 +431,60 @@ def _prepare_for_export(self, **kwargs):
self.model.remove_norms()
super()._prepare_for_export(**kwargs)

tensor_shape = ('T') if self.export_config["enable_ragged_batches"] else ('B', 'T')

# Define input_types and output_types as required by export()
self._input_types = {
"text": NeuralType(('B', 'T'), TokenIndex()),
"text": NeuralType(tensor_shape, TokenIndex()),
"lens": NeuralType(('B')),
"speaker_id": NeuralType(('B'), Index()),
"speaker_id_text": NeuralType(('B'), Index()),
"speaker_id_attributes": NeuralType(('B'), Index()),
"pitch": NeuralType(('B', 'T'), RegressionValuesType()),
"pace": NeuralType(('B', 'T')),
"volume": NeuralType(('B', 'T'), optional=True),
"pitch": NeuralType(tensor_shape, RegressionValuesType()),
"pace": NeuralType(tensor_shape),
}
self._output_types = {
"spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
"num_frames": NeuralType(('B'), TokenDurationType()),
"durs_predicted": NeuralType(('B', 'T_text'), TokenDurationType()),
"volume_aligned": NeuralType(('B', 'T_spec'), RegressionValuesType()),
}
if self.export_config["enable_volume"]:
self._input_types["volume"] = NeuralType(tensor_shape, optional=True)
self._output_types["volume_aligned"] = NeuralType(('B', 'T_spec'), RegressionValuesType())

def input_example(self, max_batch=1, max_dim=400):
par = next(self.parameters())
sz = (max_batch, max_dim)
# sz = (max_batch * max_dim,)
# Pick up only pronounceable tokens
inp = torch.randint(32, 64, sz, device=par.device, dtype=torch.int64)
speaker = torch.randint(0, 1, (max_batch,), device=par.device, dtype=torch.int64)
pitch = torch.randn(sz, device=par.device, dtype=torch.float32) * 0.5
pace = torch.clamp(torch.randn(sz, device=par.device, dtype=torch.float32) * 0.1 + 1, min=0.2, max=2.0)
volume = torch.clamp(torch.randn(sz, device=par.device, dtype=torch.float32) * 0.1 + 1, min=0.2, max=2.0)
# batch_lengths = torch.zeros((max_batch + 1), device=par.device, dtype=torch.int64)
# left_over_size = sz[0]
# batch_lengths[0] = 0
# for i in range(1, max_batch):
# length = torch.randint(1, left_over_size - (max_batch - i), (1,), device=par.device)
# batch_lengths[i] = length + batch_lengths[i - 1]
# left_over_size -= length.detach().cpu().numpy()[0]
# batch_lengths[-1] = left_over_size + batch_lengths[-2]
# sum = 0
# index = 1
# while index < len(batch_lengths):
# sum += batch_lengths[index] - batch_lengths[index - 1]
# index += 1
# assert sum == sz[0], f"sum: {sum}, sz: {sz[0]}, lengths:{batch_lengths}"

par = next(self.model.parameters())
inputs = sample_tts_input(self.export_config, par.device, max_batch=max_batch, max_dim=max_dim)
speaker = inputs["speaker"]
inp = inputs['text']
pad_id = self.tokenizer.pad
inp[inp == pad_id] = pad_id - 1 if pad_id > 0 else pad_id + 1

lens = []
for i, _ in enumerate(inp):
len_i = random.randint(64, max_dim)
lens.append(len_i)
# inp[i, len_i:] = pad_id
lens = torch.tensor(lens, device=par.device, dtype=torch.int32)
lens[0] = max_dim

inputs = {
new_inputs = {
'text': inp,
'lens': lens,
# 'batch_lengths': batch_lengths,
'lens': inputs['batch_lengths'],
'speaker_id': speaker,
'speaker_id_text': speaker,
'speaker_id_attributes': speaker,
'pitch': pitch,
'pace': pace,
'volume': volume,
'pitch': inputs['pitch'],
'pace': inputs['pace'],
'volume': inputs['volume'],
}
return (inputs,)

return (new_inputs,)

def forward_for_export(
self, text, lens, speaker_id, speaker_id_text, speaker_id_attributes, pitch, pace, volume,
):
lens = lens.to(dtype=torch.int64)
if self.export_config["enable_ragged_batches"]:
text, pitch, pace, volume_tensor, lens = batch_from_ragged(
text, pitch, pace, batch_lengths=lens, padding_idx=self.tokenizer_pad, volume=volume,
)
if volume is not None:
volume = volume_tensor
else:
lens = lens.to(dtype=torch.int64)

(mel, n_frames, dur, _, _) = self.model.infer(
speaker_id,
text,
Expand All @@ -506,15 +500,19 @@ def forward_for_export(
pitch_shift=pitch,
pace=pace,
).values()
# Need to reshape as in infer patch
durs_predicted = dur.float()
truncated_length = torch.max(lens)
volume_extended, _ = regulate_len(
durs_predicted,
volume[:, :truncated_length].unsqueeze(-1),
pace[:, :truncated_length],
group_size=self.model.n_group_size,
dur_lens=lens,
)
volume_extended = volume_extended.squeeze(2).float()
return mel.float(), n_frames, dur.float(), volume_extended
ret_values = (mel.float(), n_frames, dur.float())

if volume is not None:
# Need to reshape as in infer patch
durs_predicted = dur.float()
truncated_length = torch.max(lens)
volume_extended, _ = regulate_len(
durs_predicted,
volume[:, :truncated_length].unsqueeze(-1),
pace[:, :truncated_length],
group_size=self.model.n_group_size,
dur_lens=lens,
)
volume_extended = volume_extended.squeeze(2).float()
ret_values = ret_values + (volume_extended,)
return ret_values
2 changes: 1 addition & 1 deletion nemo/collections/tts/modules/radtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ def infer(
dur = self.dur_pred_layer.infer(txt_enc, spk_vec_text, lens=in_lens)
dur = pad_dur(dur, txt_enc)
dur = dur[:, 0]
dur = dur.clamp(1, token_duration_max)
dur = dur.clamp(0, token_duration_max)

if pace is None:
pace = txt_enc.new_ones((batch_size, txt_len_pad_removed))
Expand Down
Loading

0 comments on commit 6bd21ce

Please sign in to comment.