Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ragged batching changes for RadTTS, some refactoring #6020

Merged
merged 1 commit into from
Feb 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 84 additions & 1 deletion nemo/collections/tts/helpers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def sort_tensor(


def unsort_tensor(ordered: torch.Tensor, indices: torch.Tensor, dim: Optional[int] = 0) -> torch.Tensor:
unsort_ids = indices.gather(0, indices.argsort(0))
unsort_ids = indices.gather(0, indices.argsort(0, descending=True))
return torch.index_select(ordered, dim, unsort_ids)


Expand Down Expand Up @@ -706,3 +706,86 @@ def mask_sequence_tensor(tensor: torch.Tensor, lengths: torch.Tensor):
raise ValueError("Can only mask tensors of shape B x D x L and B x D1 x D2 x L")

return tensor * mask


@torch.jit.script
def batch_from_ragged(
text: torch.Tensor,
pitch: torch.Tensor,
pace: torch.Tensor,
batch_lengths: torch.Tensor,
padding_idx: int = -1,
volume: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

batch_lengths = batch_lengths.to(dtype=torch.int64)
max_len = torch.max(batch_lengths[1:] - batch_lengths[:-1])

index = 1
num_batches = batch_lengths.shape[0] - 1
texts = torch.zeros(num_batches, max_len, dtype=torch.int64, device=text.device) + padding_idx
pitches = torch.ones(num_batches, max_len, dtype=torch.float32, device=text.device)
paces = torch.zeros(num_batches, max_len, dtype=torch.float32, device=text.device) + 1.0
volumes = torch.zeros(num_batches, max_len, dtype=torch.float32, device=text.device) + 1.0
lens = torch.zeros(num_batches, dtype=torch.int64, device=text.device)
last_index = index - 1
while index < batch_lengths.shape[0]:
seq_start = batch_lengths[last_index]
seq_end = batch_lengths[index]
cur_seq_len = seq_end - seq_start
lens[last_index] = cur_seq_len
texts[last_index, :cur_seq_len] = text[seq_start:seq_end]
pitches[last_index, :cur_seq_len] = pitch[seq_start:seq_end]
paces[last_index, :cur_seq_len] = pace[seq_start:seq_end]
if volume is not None:
volumes[last_index, :cur_seq_len] = volume[seq_start:seq_end]
last_index = index
index += 1

return texts, pitches, paces, volumes, lens


def sample_tts_input(
export_config, device, max_batch=1, max_dim=127,
):
"""
Generates input examples for tracing etc.
Returns:
A tuple of input examples.
"""
sz = (max_batch * max_dim,) if export_config["enable_ragged_batches"] else (max_batch, max_dim)
inp = torch.randint(*export_config["emb_range"], sz, device=device, dtype=torch.int64)
pitch = torch.randn(sz, device=device, dtype=torch.float32) * 0.5
pace = torch.clamp(torch.randn(sz, device=device, dtype=torch.float32) * 0.1 + 1, min=0.01)
inputs = {'text': inp, 'pitch': pitch, 'pace': pace}
if export_config["enable_ragged_batches"]:
batch_lengths = torch.zeros((max_batch + 1), device=device, dtype=torch.int32)
left_over_size = sz[0]
batch_lengths[0] = 0
for i in range(1, max_batch):
equal_len = (left_over_size - (max_batch - i)) // (max_batch - i)
length = torch.randint(equal_len // 2, equal_len, (1,), device=device, dtype=torch.int32)
batch_lengths[i] = length + batch_lengths[i - 1]
left_over_size -= length.detach().cpu().numpy()[0]
batch_lengths[-1] = left_over_size + batch_lengths[-2]

sum = 0
index = 1
while index < len(batch_lengths):
sum += batch_lengths[index] - batch_lengths[index - 1]
index += 1
assert sum == sz[0], f"sum: {sum}, sz: {sz[0]}, lengths:{batch_lengths}"
else:
batch_lengths = torch.randint(max_dim // 2, max_dim, (max_batch,), device=device, dtype=torch.int32)
batch_lengths[0] = max_dim
inputs['batch_lengths'] = batch_lengths

if export_config["enable_volume"]:
volume = torch.clamp(torch.randn(sz, device=device, dtype=torch.float32) * 0.1 + 1, min=0.01)
inputs['volume'] = volume

if "num_speakers" in export_config:
inputs['speaker'] = torch.randint(
0, export_config["num_speakers"], (max_batch,), device=device, dtype=torch.int64
)
return inputs
90 changes: 18 additions & 72 deletions nemo/collections/tts/models/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@
from pytorch_lightning.loggers import TensorBoardLogger

from nemo.collections.common.parts.preprocessing import parsers
from nemo.collections.tts.helpers.helpers import plot_alignment_to_numpy, plot_spectrogram_to_numpy, process_batch
from nemo.collections.tts.helpers.helpers import (
batch_from_ragged,
plot_alignment_to_numpy,
plot_spectrogram_to_numpy,
process_batch,
sample_tts_input,
)
from nemo.collections.tts.losses.aligner_loss import BinLoss, ForwardSumLoss
from nemo.collections.tts.losses.fastpitchloss import DurationLoss, EnergyLoss, MelLoss, PitchLoss
from nemo.collections.tts.models.base import SpectrogramGenerator
Expand Down Expand Up @@ -156,7 +162,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
speaker_emb_condition_aligner,
)
self._input_types = self._output_types = None
self.export_config = {"enable_volume": False, "enable_ragged_batches": False}
self.export_config = {
"emb_range": (0, self.fastpitch.encoder.word_emb.num_embeddings),
"enable_volume": False,
"enable_ragged_batches": False,
}
if self.fastpitch.speaker_emb is not None:
self.export_config["num_speakers"] = cfg.n_speakers

def _get_default_text_tokenizer_conf(self):
text_tokenizer: TextTokenizerConfig = TextTokenizerConfig()
Expand Down Expand Up @@ -665,46 +677,14 @@ def input_example(self, max_batch=1, max_dim=44):
A tuple of input examples.
"""
par = next(self.fastpitch.parameters())
sz = (max_batch * max_dim,) if self.export_config["enable_ragged_batches"] else (max_batch, max_dim)
inp = torch.randint(
0, self.fastpitch.encoder.word_emb.num_embeddings, sz, device=par.device, dtype=torch.int64
)
pitch = torch.randn(sz, device=par.device, dtype=torch.float32) * 0.5
pace = torch.clamp(torch.randn(sz, device=par.device, dtype=torch.float32) * 0.1 + 1, min=0.01)

inputs = {'text': inp, 'pitch': pitch, 'pace': pace}

if self.export_config["enable_volume"]:
volume = torch.clamp(torch.randn(sz, device=par.device, dtype=torch.float32) * 0.1 + 1, min=0.01)
inputs['volume'] = volume
if self.export_config["enable_ragged_batches"]:
batch_lengths = torch.zeros((max_batch + 1), device=par.device, dtype=torch.int32)
left_over_size = sz[0]
batch_lengths[0] = 0
for i in range(1, max_batch):
length = torch.randint(1, left_over_size - (max_batch - i), (1,), device=par.device)
batch_lengths[i] = length + batch_lengths[i - 1]
left_over_size -= length.detach().cpu().numpy()[0]
batch_lengths[-1] = left_over_size + batch_lengths[-2]

sum = 0
index = 1
while index < len(batch_lengths):
sum += batch_lengths[index] - batch_lengths[index - 1]
index += 1
assert sum == sz[0], f"sum: {sum}, sz: {sz[0]}, lengths:{batch_lengths}"
inputs['batch_lengths'] = batch_lengths

if self.fastpitch.speaker_emb is not None:
inputs['speaker'] = torch.randint(
0, self.fastpitch.speaker_emb.num_embeddings, (max_batch,), device=par.device, dtype=torch.int64
)

inputs = sample_tts_input(self.export_config, par.device, max_batch=max_batch, max_dim=max_dim)
if 'enable_ragged_batches' not in self.export_config:
inputs.pop('batch_lengths', None)
return (inputs,)

def forward_for_export(self, text, pitch, pace, volume=None, batch_lengths=None, speaker=None):
if self.export_config["enable_ragged_batches"]:
text, pitch, pace, volume_tensor = create_batch(
text, pitch, pace, volume_tensor, lens = batch_from_ragged(
text, pitch, pace, batch_lengths, padding_idx=self.fastpitch.encoder.padding_idx, volume=volume
)
if volume is not None:
Expand Down Expand Up @@ -743,37 +723,3 @@ def interpolate_speaker(
)
new_speaker_emb = weight_speaker_1 * speaker_emb_1 + weight_speaker_2 * speaker_emb_2
self.fastpitch.speaker_emb.weight.data[new_speaker_id] = new_speaker_emb


@torch.jit.script
def create_batch(
text: torch.Tensor,
pitch: torch.Tensor,
pace: torch.Tensor,
batch_lengths: torch.Tensor,
padding_idx: int = -1,
volume: Optional[torch.Tensor] = None,
):
batch_lengths = batch_lengths.to(torch.int64)
max_len = torch.max(batch_lengths[1:] - batch_lengths[:-1])

index = 1
texts = torch.zeros(batch_lengths.shape[0] - 1, max_len, dtype=torch.int64, device=text.device) + padding_idx
pitches = torch.zeros(batch_lengths.shape[0] - 1, max_len, dtype=torch.float32, device=text.device)
paces = torch.zeros(batch_lengths.shape[0] - 1, max_len, dtype=torch.float32, device=text.device) + 1.0
volumes = torch.zeros(batch_lengths.shape[0] - 1, max_len, dtype=torch.float32, device=text.device) + 1.0

while index < batch_lengths.shape[0]:
seq_start = batch_lengths[index - 1]
seq_end = batch_lengths[index]
cur_seq_len = seq_end - seq_start

texts[index - 1, :cur_seq_len] = text[seq_start:seq_end]
pitches[index - 1, :cur_seq_len] = pitch[seq_start:seq_end]
paces[index - 1, :cur_seq_len] = pace[seq_start:seq_end]
if volume is not None:
volumes[index - 1, :cur_seq_len] = volume[seq_start:seq_end]

index += 1

return texts, pitches, paces, volumes
114 changes: 56 additions & 58 deletions nemo/collections/tts/models/radtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@
from pytorch_lightning.loggers import TensorBoardLogger

from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import BaseTokenizer
from nemo.collections.tts.helpers.helpers import plot_alignment_to_numpy, regulate_len
from nemo.collections.tts.helpers.helpers import (
batch_from_ragged,
plot_alignment_to_numpy,
regulate_len,
sample_tts_input,
)
from nemo.collections.tts.losses.radttsloss import AttentionBinarizationLoss, RADTTSLoss
from nemo.collections.tts.models.base import SpectrogramGenerator
from nemo.core.classes import Exportable
Expand Down Expand Up @@ -81,6 +86,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
self._tb_logger = None
self.cfg = cfg
self.log_train_images = False
self.export_config = {
"emb_range": (32, 64),
"enable_volume": True,
"enable_ragged_batches": False,
"num_speakers": self.model_config.n_speakers,
}
# print("intial self normalizer", self.normalizer)

def batch_dict(self, batch_data):
Expand Down Expand Up @@ -420,77 +431,60 @@ def _prepare_for_export(self, **kwargs):
self.model.remove_norms()
super()._prepare_for_export(**kwargs)

tensor_shape = ('T') if self.export_config["enable_ragged_batches"] else ('B', 'T')

# Define input_types and output_types as required by export()
self._input_types = {
"text": NeuralType(('B', 'T'), TokenIndex()),
"text": NeuralType(tensor_shape, TokenIndex()),
"lens": NeuralType(('B')),
"speaker_id": NeuralType(('B'), Index()),
"speaker_id_text": NeuralType(('B'), Index()),
"speaker_id_attributes": NeuralType(('B'), Index()),
"pitch": NeuralType(('B', 'T'), RegressionValuesType()),
"pace": NeuralType(('B', 'T')),
"volume": NeuralType(('B', 'T'), optional=True),
"pitch": NeuralType(tensor_shape, RegressionValuesType()),
"pace": NeuralType(tensor_shape),
}
self._output_types = {
"spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
"num_frames": NeuralType(('B'), TokenDurationType()),
"durs_predicted": NeuralType(('B', 'T_text'), TokenDurationType()),
"volume_aligned": NeuralType(('B', 'T_spec'), RegressionValuesType()),
}
if self.export_config["enable_volume"]:
self._input_types["volume"] = NeuralType(tensor_shape, optional=True)
self._output_types["volume_aligned"] = NeuralType(('B', 'T_spec'), RegressionValuesType())

def input_example(self, max_batch=1, max_dim=400):
par = next(self.parameters())
sz = (max_batch, max_dim)
# sz = (max_batch * max_dim,)
# Pick up only pronouncible tokens
inp = torch.randint(32, 64, sz, device=par.device, dtype=torch.int64)
speaker = torch.randint(0, 1, (max_batch,), device=par.device, dtype=torch.int64)
pitch = torch.randn(sz, device=par.device, dtype=torch.float32) * 0.5
pace = torch.clamp(torch.randn(sz, device=par.device, dtype=torch.float32) * 0.1 + 1, min=0.2, max=2.0)
volume = torch.clamp(torch.randn(sz, device=par.device, dtype=torch.float32) * 0.1 + 1, min=0.2, max=2.0)
# batch_lengths = torch.zeros((max_batch + 1), device=par.device, dtype=torch.int64)
# left_over_size = sz[0]
# batch_lengths[0] = 0
# for i in range(1, max_batch):
# length = torch.randint(1, left_over_size - (max_batch - i), (1,), device=par.device)
# batch_lengths[i] = length + batch_lengths[i - 1]
# left_over_size -= length.detach().cpu().numpy()[0]
# batch_lengths[-1] = left_over_size + batch_lengths[-2]
# sum = 0
# index = 1
# while index < len(batch_lengths):
# sum += batch_lengths[index] - batch_lengths[index - 1]
# index += 1
# assert sum == sz[0], f"sum: {sum}, sz: {sz[0]}, lengths:{batch_lengths}"

par = next(self.model.parameters())
inputs = sample_tts_input(self.export_config, par.device, max_batch=max_batch, max_dim=max_dim)
speaker = inputs["speaker"]
inp = inputs['text']
pad_id = self.tokenizer.pad
inp[inp == pad_id] = pad_id - 1 if pad_id > 0 else pad_id + 1

lens = []
for i, _ in enumerate(inp):
len_i = random.randint(64, max_dim)
lens.append(len_i)
# inp[i, len_i:] = pad_id
lens = torch.tensor(lens, device=par.device, dtype=torch.int32)
lens[0] = max_dim

inputs = {
new_inputs = {
'text': inp,
'lens': lens,
# 'batch_lengths': batch_lengths,
'lens': inputs['batch_lengths'],
'speaker_id': speaker,
'speaker_id_text': speaker,
'speaker_id_attributes': speaker,
'pitch': pitch,
'pace': pace,
'volume': volume,
'pitch': inputs['pitch'],
'pace': inputs['pace'],
'volume': inputs['volume'],
}
return (inputs,)

return (new_inputs,)

def forward_for_export(
self, text, lens, speaker_id, speaker_id_text, speaker_id_attributes, pitch, pace, volume,
):
lens = lens.to(dtype=torch.int64)
if self.export_config["enable_ragged_batches"]:
text, pitch, pace, volume_tensor, lens = batch_from_ragged(
text, pitch, pace, batch_lengths=lens, padding_idx=self.tokenizer_pad, volume=volume,
)
if volume is not None:
volume = volume_tensor
else:
lens = lens.to(dtype=torch.int64)

(mel, n_frames, dur, _, _) = self.model.infer(
speaker_id,
text,
Expand All @@ -506,15 +500,19 @@ def forward_for_export(
pitch_shift=pitch,
pace=pace,
).values()
# Need to reshape as in infer patch
durs_predicted = dur.float()
truncated_length = torch.max(lens)
volume_extended, _ = regulate_len(
durs_predicted,
volume[:, :truncated_length].unsqueeze(-1),
pace[:, :truncated_length],
group_size=self.model.n_group_size,
dur_lens=lens,
)
volume_extended = volume_extended.squeeze(2).float()
return mel.float(), n_frames, dur.float(), volume_extended
ret_values = (mel.float(), n_frames, dur.float())

if volume is not None:
# Need to reshape as in infer patch
durs_predicted = dur.float()
truncated_length = torch.max(lens)
volume_extended, _ = regulate_len(
durs_predicted,
volume[:, :truncated_length].unsqueeze(-1),
pace[:, :truncated_length],
group_size=self.model.n_group_size,
dur_lens=lens,
)
volume_extended = volume_extended.squeeze(2).float()
ret_values = ret_values + (volume_extended,)
return ret_values
2 changes: 1 addition & 1 deletion nemo/collections/tts/modules/radtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ def infer(
dur = self.dur_pred_layer.infer(txt_enc, spk_vec_text, lens=in_lens)
dur = pad_dur(dur, txt_enc)
dur = dur[:, 0]
dur = dur.clamp(1, token_duration_max)
dur = dur.clamp(0, token_duration_max)

if pace is None:
pace = txt_enc.new_ones((batch_size, txt_len_pad_removed))
Expand Down
Loading