
Merge branch 'main' into feat/inference_diar
SeanNaren committed Feb 15, 2023
2 parents 5bf2273 + 49d16f9 commit 821dd11
Showing 14 changed files with 324 additions and 181 deletions.
111 changes: 110 additions & 1 deletion docs/source/asr/datasets.rst
@@ -372,4 +372,113 @@
All instances of data from ``bucket4`` will still be trained with a batch size of 2 while all others would have a batch size of 4. As with standard bucketing, this requires ``batch_size`` to be set to 1.
If ``bucketing_batch_size`` is not specified, all datasets will be passed with the same fixed batch size as specified by the ``batch_size`` parameter.
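
For instance, the behavior described above could be requested with overrides such as the following (illustrative values, assuming four buckets):

.. code::

   model.train_ds.batch_size=1
   ++model.train_ds.bucketing_batch_size=[4,4,4,2]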

It is recommended to set bucketing strategies to ``fully_randomized`` during multi-GPU training to prevent possible dataset bias during training.


Datasets on AIStore
-------------------

`AIStore <https://aiatscale.org>`_ is an open-source lightweight object storage system focused on large-scale deep learning.
AIStore is designed to scale linearly with each added storage node. It can be deployed on any Linux machine and provides a unified namespace across multiple remote backends, such as Amazon S3, Google Cloud, and Microsoft Azure.
More details are provided in the `documentation <https://aiatscale.org/docs>`_ and the `repository <https://github.com/NVIDIA/aistore>`_ of the AIStore project.

NeMo currently supports datasets from an AIStore bucket provider under the ``ais://`` namespace.

AIStore Setup
~~~~~~~~~~~~~

NeMo currently relies on the AIStore (AIS) command-line interface (CLI) to handle the supported datasets.
The CLI is available in current NeMo Docker containers.
If necessary, the CLI can be configured using the instructions provided in the `AIStore CLI <https://aiatscale.org/docs/cli>`_ documentation.

To start using the AIS CLI to access data on an AIS cluster, an endpoint needs to be configured.
The endpoint is configured by setting the ``AIS_ENDPOINT`` environment variable before using the CLI:

.. code::

   export AIS_ENDPOINT=http://hostname:port
   ais --help

In the above, ``hostname:port`` denotes the address of an AIS gateway.
For example, the address could be ``localhost:51080`` if testing using a local `minimal production-ready standalone Docker container <https://github.com/NVIDIA/aistore/blob/master/deploy/prod/docker/single/README.md>`_.

Dataset Setup
~~~~~~~~~~~~~

Currently, both tarred and non-tarred datasets are supported.
For any dataset, the corresponding manifest file is cached locally and processed as a regular manifest file.
For non-tarred datasets, the audio data is also cached locally.
For tarred datasets, shards from the AIS cluster are used by piping ``ais get`` to WebDataset.
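
For illustration, the piping pattern resembles the following sketch (not NeMo's internal code; it assumes the ``webdataset`` package and the AIS CLI are installed, and that ``AIS_ENDPOINT`` is set):

.. code:: python

   import webdataset as wds

   # Each shard is fetched by piping ``ais get`` to stdout ('-') and read as a tar stream.
   urls = "pipe:ais get ais://bucket/shard_{1..64}.tar -"
   dataset = wds.WebDataset(urls)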

Tarred Dataset from AIS
^^^^^^^^^^^^^^^^^^^^^^^

A tarred dataset can be easily used as described in the :ref:`Tarred Datasets` section by providing paths to manifests on an AIS cluster.
For example, a tarred dataset from an AIS cluster can be configured as

.. code::

   manifest_filepath='ais://bucket/tarred_audio_manifest.json'
   tarred_audio_filepaths='ais://bucket/shard_{1..64}.tar'

:ref:`Bucketing Datasets` are configured in a similar way by providing paths on an AIS cluster.
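
For example, a bucketing dataset with two buckets stored on an AIS cluster might be configured as follows (paths illustrative):

.. code::

   manifest_filepath=[['ais://bucket/bucket1/tarred_audio_manifest.json'],['ais://bucket/bucket2/tarred_audio_manifest.json']]
   tarred_audio_filepaths=[['ais://bucket/bucket1/shard_{1..64}.tar'],['ais://bucket/bucket2/shard_{1..64}.tar']]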

Non-tarred Dataset from AIS
^^^^^^^^^^^^^^^^^^^^^^^^^^^

A non-tarred dataset can be easily used by providing a manifest file path on an AIS cluster:

.. code::

   manifest_filepath='ais://bucket/dataset_manifest.json'

Note that the manifest file is assumed to contain audio file paths relative to the manifest location.
For example, the manifest file may have lines in the following format

.. code-block:: json

   {"audio_filepath": "path/to/audio.wav", "text": "transcription of the utterance", "duration": 23.147}

The corresponding audio file would be downloaded from ``ais://bucket/path/to/audio.wav``.
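
As a sketch of this resolution logic (hypothetical helper code, not NeMo's implementation):

.. code:: python

   import posixpath

   manifest_filepath = "ais://bucket/dataset_manifest.json"
   entry_audio_filepath = "path/to/audio.wav"  # relative path taken from a manifest entry

   # Relative paths are resolved against the manifest's directory on the cluster.
   remote_path = posixpath.join(posixpath.dirname(manifest_filepath), entry_audio_filepath)
   assert remote_path == "ais://bucket/path/to/audio.wav"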

Cache Configuration
^^^^^^^^^^^^^^^^^^^

Manifests and audio files from non-tarred datasets will be cached locally.
The location of the cache can be configured by setting two environment variables:

- ``NEMO_DATA_STORE_CACHE_DIR``: path to a location which can be used to cache the data
- ``NEMO_DATA_STORE_CACHE_SHARED``: flag to denote whether the cache location is shared between the compute nodes

In a multi-node environment, the cache location may or may not be shared between the nodes.
This can be configured by setting ``NEMO_DATA_STORE_CACHE_SHARED`` to ``1`` when the location is shared between the nodes, or to ``0`` when each node has a separate cache.

When a globally shared cache is available, the data should be cached only once, by the global rank zero process.
When a node-specific cache is used, the data should be cached only once on each node, by the local rank zero process.
To control this behavior using ``torch.distributed.barrier``, instantiation of the corresponding dataloader needs to be deferred to ``ModelPT::setup``, to ensure that a distributed environment has been initialized.
This can be achieved by setting ``defer_setup`` as

.. code:: shell

   ++model.train_ds.defer_setup=true
   ++model.validation_ds.defer_setup=true
   ++model.test_ds.defer_setup=true
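
The caching discipline described above amounts to roughly the following pattern (a simplified sketch with hypothetical helper names, not NeMo's API; it assumes ``torch.distributed`` is already initialized, which is exactly why dataloader setup is deferred):

.. code:: python

   import os

   import torch.distributed as dist

   def cache_dataset(cache_is_shared: bool, download_fn) -> None:
       # Download once per cache location, then synchronize all ranks on a barrier.
       local_rank = int(os.environ.get("LOCAL_RANK", 0))
       if cache_is_shared:
           if dist.get_rank() == 0:  # one download for the whole job
               download_fn()
       elif local_rank == 0:  # one download on each node
           download_fn()
       dist.barrier()
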
Complete Example
^^^^^^^^^^^^^^^^

An example using an AIS cluster at ``hostname:port`` with a tarred dataset for training, a non-tarred dataset for validation, and node-specific caching is given below:

.. code:: shell

   export AIS_ENDPOINT=http://hostname:port \
   && export NEMO_DATA_STORE_CACHE_DIR=/tmp \
   && export NEMO_DATA_STORE_CACHE_SHARED=0 \
   && python speech_to_text_bpe.py \
   ...
   model.train_ds.manifest_filepath=ais://train_bucket/tarred_audio_manifest.json \
   model.train_ds.tarred_audio_filepaths=ais://train_bucket/audio__OP_0..511_CL_.tar \
   ++model.train_ds.defer_setup=true \
   model.validation_ds.manifest_filepath=ais://validation_bucket/validation_manifest.json \
   ++model.validation_ds.defer_setup=true
8 changes: 4 additions & 4 deletions nemo/collections/asr/losses/rnnt.py
@@ -322,10 +322,10 @@ def reduce(self, losses, target_lengths):

    @typecheck()
    def forward(self, log_probs, targets, input_lengths, target_lengths):
-        # Cast to int 32
-        targets = targets.int()
-        input_lengths = input_lengths.int()
-        target_lengths = target_lengths.int()
+        # Cast to int 64
+        targets = targets.long()
+        input_lengths = input_lengths.long()
+        target_lengths = target_lengths.long()

        max_logit_len = input_lengths.max()
        max_targets_len = target_lengths.max()
12 changes: 6 additions & 6 deletions nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py
@@ -50,9 +50,9 @@ def check_dim(var, dim, name):

def certify_inputs(log_probs, labels, lengths, label_lengths):
    # check_type(log_probs, torch.float32, "log_probs")
-    check_type(labels, torch.int32, "labels")
-    check_type(label_lengths, torch.int32, "label_lengths")
-    check_type(lengths, torch.int32, "lengths")
+    check_type(labels, torch.int64, "labels")
+    check_type(label_lengths, torch.int64, "label_lengths")
+    check_type(lengths, torch.int64, "lengths")
    check_contiguous(log_probs, "log_probs")
    check_contiguous(labels, "labels")
    check_contiguous(label_lengths, "label_lengths")
@@ -357,8 +357,8 @@ def forward(self, acts, labels, act_lens, label_lens):
    torch.manual_seed(0)

    acts = torch.randn(1, 2, 5, 3)
-    labels = torch.tensor([[0, 2, 1, 2]], dtype=torch.int32)
-    act_lens = torch.tensor([2], dtype=torch.int32)
-    label_lens = torch.tensor([len(labels[0])], dtype=torch.int32)
+    labels = torch.tensor([[0, 2, 1, 2]], dtype=torch.int64)
+    act_lens = torch.tensor([2], dtype=torch.int64)
+    label_lens = torch.tensor([len(labels[0])], dtype=torch.int64)

    loss_val = loss(acts, labels, act_lens, label_lens)
6 changes: 3 additions & 3 deletions nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py
@@ -371,9 +371,9 @@ def check_dim(var, dim, name):

def certify_inputs(log_probs, labels, lengths, label_lengths):
    # check_type(log_probs, torch.float32, "log_probs")
-    check_type(labels, torch.int32, "labels")
-    check_type(label_lengths, torch.int32, "label_lengths")
-    check_type(lengths, torch.int32, "lengths")
+    check_type(labels, torch.int64, "labels")
+    check_type(label_lengths, torch.int64, "label_lengths")
+    check_type(lengths, torch.int64, "lengths")
    check_contiguous(log_probs, "log_probs")
    check_contiguous(labels, "labels")
    check_contiguous(label_lengths, "label_lengths")
85 changes: 84 additions & 1 deletion nemo/collections/tts/helpers/helpers.py
@@ -164,7 +164,7 @@ def sort_tensor(


def unsort_tensor(ordered: torch.Tensor, indices: torch.Tensor, dim: Optional[int] = 0) -> torch.Tensor:
-    unsort_ids = indices.gather(0, indices.argsort(0))
+    unsort_ids = indices.gather(0, indices.argsort(0, descending=True))
    return torch.index_select(ordered, dim, unsort_ids)


@@ -706,3 +706,86 @@ def mask_sequence_tensor(tensor: torch.Tensor, lengths: torch.Tensor):
        raise ValueError("Can only mask tensors of shape B x D x L and B x D1 x D2 x L")

    return tensor * mask


@torch.jit.script
def batch_from_ragged(
    text: torch.Tensor,
    pitch: torch.Tensor,
    pace: torch.Tensor,
    batch_lengths: torch.Tensor,
    padding_idx: int = -1,
    volume: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

    batch_lengths = batch_lengths.to(dtype=torch.int64)
    max_len = torch.max(batch_lengths[1:] - batch_lengths[:-1])

    index = 1
    num_batches = batch_lengths.shape[0] - 1
    texts = torch.zeros(num_batches, max_len, dtype=torch.int64, device=text.device) + padding_idx
    pitches = torch.ones(num_batches, max_len, dtype=torch.float32, device=text.device)
    paces = torch.zeros(num_batches, max_len, dtype=torch.float32, device=text.device) + 1.0
    volumes = torch.zeros(num_batches, max_len, dtype=torch.float32, device=text.device) + 1.0
    lens = torch.zeros(num_batches, dtype=torch.int64, device=text.device)
    last_index = index - 1
    while index < batch_lengths.shape[0]:
        seq_start = batch_lengths[last_index]
        seq_end = batch_lengths[index]
        cur_seq_len = seq_end - seq_start
        lens[last_index] = cur_seq_len
        texts[last_index, :cur_seq_len] = text[seq_start:seq_end]
        pitches[last_index, :cur_seq_len] = pitch[seq_start:seq_end]
        paces[last_index, :cur_seq_len] = pace[seq_start:seq_end]
        if volume is not None:
            volumes[last_index, :cur_seq_len] = volume[seq_start:seq_end]
        last_index = index
        index += 1

    return texts, pitches, paces, volumes, lens


def sample_tts_input(
    export_config, device, max_batch=1, max_dim=127,
):
    """
    Generates input examples for tracing etc.
    Returns:
        A dict of input examples.
    """
    sz = (max_batch * max_dim,) if export_config["enable_ragged_batches"] else (max_batch, max_dim)
    inp = torch.randint(*export_config["emb_range"], sz, device=device, dtype=torch.int64)
    pitch = torch.randn(sz, device=device, dtype=torch.float32) * 0.5
    pace = torch.clamp(torch.randn(sz, device=device, dtype=torch.float32) * 0.1 + 1, min=0.01)
    inputs = {'text': inp, 'pitch': pitch, 'pace': pace}
    if export_config["enable_ragged_batches"]:
        batch_lengths = torch.zeros((max_batch + 1), device=device, dtype=torch.int32)
        left_over_size = sz[0]
        batch_lengths[0] = 0
        for i in range(1, max_batch):
            equal_len = (left_over_size - (max_batch - i)) // (max_batch - i)
            length = torch.randint(equal_len // 2, equal_len, (1,), device=device, dtype=torch.int32)
            batch_lengths[i] = length + batch_lengths[i - 1]
            left_over_size -= length.detach().cpu().numpy()[0]
        batch_lengths[-1] = left_over_size + batch_lengths[-2]

        sum = 0
        index = 1
        while index < len(batch_lengths):
            sum += batch_lengths[index] - batch_lengths[index - 1]
            index += 1
        assert sum == sz[0], f"sum: {sum}, sz: {sz[0]}, lengths:{batch_lengths}"
    else:
        batch_lengths = torch.randint(max_dim // 2, max_dim, (max_batch,), device=device, dtype=torch.int32)
        batch_lengths[0] = max_dim
    inputs['batch_lengths'] = batch_lengths

    if export_config["enable_volume"]:
        volume = torch.clamp(torch.randn(sz, device=device, dtype=torch.float32) * 0.1 + 1, min=0.01)
        inputs['volume'] = volume

    if "num_speakers" in export_config:
        inputs['speaker'] = torch.randint(
            0, export_config["num_speakers"], (max_batch,), device=device, dtype=torch.int64
        )
    return inputs
90 changes: 18 additions & 72 deletions nemo/collections/tts/models/fastpitch.py
@@ -22,7 +22,13 @@
from pytorch_lightning.loggers import TensorBoardLogger

from nemo.collections.common.parts.preprocessing import parsers
-from nemo.collections.tts.helpers.helpers import plot_alignment_to_numpy, plot_spectrogram_to_numpy, process_batch
+from nemo.collections.tts.helpers.helpers import (
+    batch_from_ragged,
+    plot_alignment_to_numpy,
+    plot_spectrogram_to_numpy,
+    process_batch,
+    sample_tts_input,
+)
from nemo.collections.tts.losses.aligner_loss import BinLoss, ForwardSumLoss
from nemo.collections.tts.losses.fastpitchloss import DurationLoss, EnergyLoss, MelLoss, PitchLoss
from nemo.collections.tts.models.base import SpectrogramGenerator
@@ -156,7 +162,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
            speaker_emb_condition_aligner,
        )
        self._input_types = self._output_types = None
-        self.export_config = {"enable_volume": False, "enable_ragged_batches": False}
+        self.export_config = {
+            "emb_range": (0, self.fastpitch.encoder.word_emb.num_embeddings),
+            "enable_volume": False,
+            "enable_ragged_batches": False,
+        }
+        if self.fastpitch.speaker_emb is not None:
+            self.export_config["num_speakers"] = cfg.n_speakers

    def _get_default_text_tokenizer_conf(self):
        text_tokenizer: TextTokenizerConfig = TextTokenizerConfig()
@@ -665,46 +677,14 @@ def input_example(self, max_batch=1, max_dim=44):
            A tuple of input examples.
        """
        par = next(self.fastpitch.parameters())
-        sz = (max_batch * max_dim,) if self.export_config["enable_ragged_batches"] else (max_batch, max_dim)
-        inp = torch.randint(
-            0, self.fastpitch.encoder.word_emb.num_embeddings, sz, device=par.device, dtype=torch.int64
-        )
-        pitch = torch.randn(sz, device=par.device, dtype=torch.float32) * 0.5
-        pace = torch.clamp(torch.randn(sz, device=par.device, dtype=torch.float32) * 0.1 + 1, min=0.01)
-
-        inputs = {'text': inp, 'pitch': pitch, 'pace': pace}
-
-        if self.export_config["enable_volume"]:
-            volume = torch.clamp(torch.randn(sz, device=par.device, dtype=torch.float32) * 0.1 + 1, min=0.01)
-            inputs['volume'] = volume
-        if self.export_config["enable_ragged_batches"]:
-            batch_lengths = torch.zeros((max_batch + 1), device=par.device, dtype=torch.int32)
-            left_over_size = sz[0]
-            batch_lengths[0] = 0
-            for i in range(1, max_batch):
-                length = torch.randint(1, left_over_size - (max_batch - i), (1,), device=par.device)
-                batch_lengths[i] = length + batch_lengths[i - 1]
-                left_over_size -= length.detach().cpu().numpy()[0]
-            batch_lengths[-1] = left_over_size + batch_lengths[-2]
-
-            sum = 0
-            index = 1
-            while index < len(batch_lengths):
-                sum += batch_lengths[index] - batch_lengths[index - 1]
-                index += 1
-            assert sum == sz[0], f"sum: {sum}, sz: {sz[0]}, lengths:{batch_lengths}"
-            inputs['batch_lengths'] = batch_lengths
-
-        if self.fastpitch.speaker_emb is not None:
-            inputs['speaker'] = torch.randint(
-                0, self.fastpitch.speaker_emb.num_embeddings, (max_batch,), device=par.device, dtype=torch.int64
-            )
-
+        inputs = sample_tts_input(self.export_config, par.device, max_batch=max_batch, max_dim=max_dim)
+        if 'enable_ragged_batches' not in self.export_config:
+            inputs.pop('batch_lengths', None)
        return (inputs,)

    def forward_for_export(self, text, pitch, pace, volume=None, batch_lengths=None, speaker=None):
        if self.export_config["enable_ragged_batches"]:
-            text, pitch, pace, volume_tensor = create_batch(
+            text, pitch, pace, volume_tensor, lens = batch_from_ragged(
                text, pitch, pace, batch_lengths, padding_idx=self.fastpitch.encoder.padding_idx, volume=volume
            )
            if volume is not None:
@@ -743,37 +723,3 @@ def interpolate_speaker(
        )
        new_speaker_emb = weight_speaker_1 * speaker_emb_1 + weight_speaker_2 * speaker_emb_2
        self.fastpitch.speaker_emb.weight.data[new_speaker_id] = new_speaker_emb


-@torch.jit.script
-def create_batch(
-    text: torch.Tensor,
-    pitch: torch.Tensor,
-    pace: torch.Tensor,
-    batch_lengths: torch.Tensor,
-    padding_idx: int = -1,
-    volume: Optional[torch.Tensor] = None,
-):
-    batch_lengths = batch_lengths.to(torch.int64)
-    max_len = torch.max(batch_lengths[1:] - batch_lengths[:-1])
-
-    index = 1
-    texts = torch.zeros(batch_lengths.shape[0] - 1, max_len, dtype=torch.int64, device=text.device) + padding_idx
-    pitches = torch.zeros(batch_lengths.shape[0] - 1, max_len, dtype=torch.float32, device=text.device)
-    paces = torch.zeros(batch_lengths.shape[0] - 1, max_len, dtype=torch.float32, device=text.device) + 1.0
-    volumes = torch.zeros(batch_lengths.shape[0] - 1, max_len, dtype=torch.float32, device=text.device) + 1.0
-
-    while index < batch_lengths.shape[0]:
-        seq_start = batch_lengths[index - 1]
-        seq_end = batch_lengths[index]
-        cur_seq_len = seq_end - seq_start
-
-        texts[index - 1, :cur_seq_len] = text[seq_start:seq_end]
-        pitches[index - 1, :cur_seq_len] = pitch[seq_start:seq_end]
-        paces[index - 1, :cur_seq_len] = pace[seq_start:seq_end]
-        if volume is not None:
-            volumes[index - 1, :cur_seq_len] = volume[seq_start:seq_end]
-
-        index += 1
-
-    return texts, pitches, paces, volumes
(The remaining 8 changed files are not shown.)
