NeMo + Lhotse integration (#7880)
* Lhotse integration squashed PR

Signed-off-by: Piotr Żelasko <[email protected]>

* Code review - Som

Signed-off-by: Piotr Żelasko <[email protected]>

* Update copyright headers to 2024

Signed-off-by: Piotr Żelasko <[email protected]>

* Fix NLP imports

Signed-off-by: Piotr Żelasko <[email protected]>

* Code review - Vahid

Signed-off-by: Piotr Żelasko <[email protected]>

---------

Signed-off-by: Piotr Żelasko <[email protected]>
Signed-off-by: Piotr Żelasko <[email protected]>
Signed-off-by: Pablo Garay <[email protected]>
pzelasko authored and pablo-garay committed Mar 19, 2024
1 parent f975b9c commit 8c7bed2
Showing 16 changed files with 1,696 additions and 2 deletions.
1 change: 1 addition & 0 deletions Dockerfile
@@ -102,6 +102,7 @@ RUN INSTALL_MSG=$(/bin/bash /tmp/torchaudio_build/scripts/installers/install_tor

# install nemo dependencies
WORKDIR /tmp/nemo
ENV LHOTSE_REQUIRE_TORCHAUDIO=0
COPY requirements .
RUN for f in $(ls requirements*.txt); do pip3 install --disable-pip-version-check --no-cache-dir -r $f; done

123 changes: 123 additions & 0 deletions docs/source/asr/datasets.rst
@@ -507,6 +507,129 @@ An example using an AIS cluster at ``hostname:port`` with a tarred dataset for t
.. _Hybrid-ASR-TTS_model__Text-Only-Data:


Lhotse Dataloading
------------------

NeMo supports using `Lhotse`_, a speech data handling library, as a dataloading option. The key features of Lhotse used in NeMo are:

* Dynamic batch sizes
    Lhotse samples mini-batches to satisfy the constraint of total speech duration in a mini-batch (``batch_duration``),
    rather than a specific number of examples (i.e., batch size).
* Dynamic bucketing
    Instead of statically pre-bucketing the data, Lhotse allocates training examples to buckets dynamically.
    This allows more rapid experimentation with bucketing settings (number of buckets, specific placement of bucket duration bins)
    to minimize the amount of padding and accelerate training.
* Quadratic duration penalty
    Adding a quadratic penalty to an utterance's duration makes it possible to sample mini-batches so that
    GPU utilization is more consistent across big batches of short utterances and small batches of long utterances when using
    models with quadratic time/memory complexity (such as transformers).
* Dynamic weighted data source multiplexing
    An approach to combining diverse data sources (e.g. multiple domains, languages, tasks)
    where each data source is treated as a separate stream with its own sampling probability. The resulting data stream is a
    multiplexer that samples from each sub-stream. This approach ensures that the distribution of different sources is approximately
    constant in time (i.e., stationary); in fact, each mini-batch will have roughly the same ratio of data coming from each source.
    Since the multiplexing is done dynamically, it is very easy to tune the sampling weights.
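
The weighted multiplexing idea described above can be illustrated with a minimal pure-Python sketch. It is not Lhotse's implementation; the ``multiplex`` function, the toy streams, and the weights are hypothetical names and values chosen only for illustration.

```python
import random
from itertools import cycle, islice


def multiplex(streams, weights, seed=0):
    """Yield items one at a time, choosing the source stream at random for each item.

    A stream is picked with probability proportional to its weight.
    Iteration stops as soon as the chosen stream is exhausted.
    """
    rng = random.Random(seed)
    iterators = [iter(s) for s in streams]
    indices = range(len(iterators))
    while True:
        idx = rng.choices(indices, weights=weights, k=1)[0]
        try:
            yield next(iterators[idx])
        except StopIteration:
            return


# Two toy "datasets" (hypothetical), repeated forever so neither runs out.
english = cycle(["en"])
spanish = cycle(["es"])

mixed = multiplex([english, spanish], weights=[0.8, 0.2], seed=42)
sample = list(islice(mixed, 10_000))
en_share = sample.count("en") / len(sample)
print(f"share of 'en' items: {en_share:.2f}")
```

Because the source is re-drawn for every item, any long enough window of the mixed stream contains the sources in roughly the configured ratio, and changing the ratio only requires changing ``weights``.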

Lhotse dataloading supports the following types of inputs:

* NeMo manifests
    Regular NeMo JSON manifests.
* NeMo tarred data
    Tarred NeMo JSON manifests + audio tar files; we also support combination of multiple NeMo
    tarred data sources (e.g., multiple buckets of NeMo data or multiple datasets) via dynamic multiplexing.
* Lhotse CutSet manifests
    Regular Lhotse CutSet manifests (typically gzipped JSONL).
    See `Lhotse Cuts documentation`_ to learn more about Lhotse data formats.
* Lhotse Shar data
    Lhotse Shar is a data format that also uses tar files for sequential data loading,
    but is designed to be modular (i.e., easily extensible with new data sources and with new feature fields).
    More details can be found here: |tutorial_shar|

.. caution:: As of now, Lhotse is supported in most ASR model configurations. We aim to gradually extend this support to other speech tasks.

.. _Lhotse: https://github.com/lhotse-speech/lhotse
.. _Lhotse Cuts documentation: https://lhotse.readthedocs.io/en/latest/cuts.html
.. |tutorial_shar| image:: https://colab.research.google.com/assets/colab-badge.svg
    :target: https://colab.research.google.com/github/lhotse-speech/lhotse/blob/master/examples/04-lhotse-shar.ipynb

Enabling Lhotse via configuration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. note:: Using Lhotse with tarred datasets will make the dataloader infinite, ditching the notion of an "epoch". "Epoch" may still be logged in W&B/TensorBoard, but it will correspond to the number of executed training loops between validation loops.

Start with an existing NeMo experiment YAML configuration. Typically, you'll only need to add a few options to enable Lhotse.
These options are::

    # NeMo generic dataloading arguments
    model.train_ds.manifest_filepath=...
    model.train_ds.tarred_audio_filepaths=...   # for tarred datasets only
    model.train_ds.num_workers=4
    model.train_ds.min_duration=0.3             # optional
    model.train_ds.max_duration=30.0            # optional
    model.train_ds.shuffle=true                 # optional

    # Lhotse dataloading related arguments
    ++model.train_ds.use_lhotse=True
    ++model.train_ds.batch_duration=1100
    ++model.train_ds.quadratic_duration=30
    ++model.train_ds.num_buckets=30
    ++model.train_ds.num_cuts_for_bins_estimate=10000
    ++model.train_ds.bucket_buffer_size=10000
    ++model.train_ds.shuffle_buffer_size=10000

    # PyTorch Lightning related arguments
    ++trainer.use_distributed_sampler=false
    ++trainer.limit_train_batches=1000
    trainer.val_check_interval=1000
    trainer.max_steps=300000

.. note:: The default values above are a reasonable starting point for a hybrid RNN-T + CTC ASR model on a 32GB GPU with a data distribution dominated by 15s long utterances.

Let's briefly go over each of the Lhotse dataloading arguments:

* ``use_lhotse`` enables Lhotse dataloading.
* ``batch_duration`` is the maximum total duration of utterances in a mini-batch and controls the batch size; the shorter the utterances, the bigger the batch size, and vice versa.
* ``quadratic_duration`` adds a quadratically growing penalty for long utterances; useful with bucketing and with transformer-type models. An utterance whose duration equals the value set here counts as if it were twice as long.
* ``num_buckets`` is the number of buckets in the bucketing sampler. A bigger value means less padding but also less randomization.
* ``num_cuts_for_bins_estimate`` is the number of utterances we will sample before the start of training to estimate the duration bins for buckets. A larger number results in a more accurate estimate but also a longer delay before training starts.
* ``bucket_buffer_size`` is the number of utterances (data and metadata) we will hold in memory to be distributed between buckets. With a bigger ``batch_duration``, this number may need to be increased for the dynamic bucketing sampler to work properly (typically it will emit a warning if this is too low).
* ``shuffle_buffer_size`` is an extra number of utterances we will hold in memory to perform approximate shuffling (via reservoir-like sampling). A bigger number means more memory usage but also better randomness.
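
To make the interaction between ``batch_duration`` and ``quadratic_duration`` concrete, here is a minimal sketch of duration-constrained batching. It assumes the penalty formula ``duration + duration**2 / quadratic_duration``, which is one formula consistent with the description above (an utterance as long as ``quadratic_duration`` counts double); the function names are hypothetical and the greedy loop is a simplification, not the Lhotse sampler itself.

```python
def effective_duration(duration, quadratic_duration=None):
    """Duration used for batch budgeting, with an optional quadratic penalty (assumed formula)."""
    if quadratic_duration is None:
        return duration
    return duration + duration ** 2 / quadratic_duration


def batch_by_duration(durations, batch_duration, quadratic_duration=None):
    """Greedily group utterances so that each batch stays within the duration budget."""
    batches, current, total = [], [], 0.0
    for d in durations:
        cost = effective_duration(d, quadratic_duration)
        if current and total + cost > batch_duration:
            batches.append(current)
            current, total = [], 0.0
        current.append(d)
        total += cost
    if current:
        batches.append(current)
    return batches


short_utts = [2.0] * 100   # toy data: 2-second utterances
long_utts = [30.0] * 8     # toy data: 30-second utterances

# Without the penalty, the budget of 120 s fits 60 short or 4 long utterances.
plain_short = len(batch_by_duration(short_utts, 120.0)[0])
plain_long = len(batch_by_duration(long_utts, 120.0)[0])

# With quadratic_duration=30, long utterances cost twice as much; short ones barely change.
penalized_short = len(batch_by_duration(short_utts, 120.0, quadratic_duration=30.0)[0])
penalized_long = len(batch_by_duration(long_utts, 120.0, quadratic_duration=30.0)[0])

print(plain_short, plain_long, penalized_short, penalized_long)
```

In this toy setting the penalty halves the number of long utterances per batch while the batches of short utterances shrink only slightly, which is the effect the option is meant to achieve for models whose cost grows quadratically with sequence length.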

The PyTorch Lightning ``trainer`` related arguments:

* ``use_distributed_sampler=false`` is required because Lhotse has its own handling of distributed sampling.
* ``val_check_interval``/``limit_train_batches``
    These are required for dataloaders with tarred/Shar datasets
    because Lhotse makes the dataloader infinite, so we'd never go past epoch 0. This approach guarantees
    we will never hang the training because the dataloader in some node has fewer mini-batches than the others
    in some epochs. The value provided here will be the effective length of each "pseudo-epoch" after which we'll
    trigger the validation loop.
* ``max_steps`` is the total number of steps we expect to be training for. It is required for the same reason as ``limit_train_batches``; since we'd never go past epoch 0, the training would never finish.

Some other Lhotse related arguments we support:

* ``cuts_path`` can be provided to read data from a Lhotse CutSet manifest instead of a NeMo manifest.
    Specifying this option will result in ``manifest_filepath`` and ``tarred_audio_filepaths`` being ignored.
* ``shar_path``
    Can be provided to read data from a Lhotse Shar manifest instead of a NeMo manifest.
    This argument can be a string (single Shar directory), a list of strings (Shar directories),
    or a list of 2-item lists, where the first item is a Shar directory path, and the other is a sampling weight.
    Specifying this option will result in ``manifest_filepath`` and ``tarred_audio_filepaths`` being ignored.
* ``bucket_duration_bins``
    Duration bins are a list of float values (seconds) that, when provided, will skip the initial bucket bin estimation
    and save some time. The list has to have a length of ``num_buckets - 1``. An optimal value can be obtained by running the CLI command:
    ``lhotse cut estimate-bucket-bins -b $num_buckets my-cuts.jsonl.gz``
* ``use_bucketing`` is a boolean which indicates if we want to enable/disable dynamic bucketing. By default it's enabled.
* ``text_field`` is the name of the key in the JSON (NeMo) manifest from which we should be reading text (default="text").
* ``lang_field`` is the name of the key in the JSON (NeMo) manifest from which we should be reading the language tag (default="lang"). This is useful when working e.g. with ``AggregateTokenizer``.
* ``batch_size``
    Limits the number of examples in a mini-batch to this number, when combined with ``batch_duration``.
    When ``batch_duration`` is not set, it acts as a static batch size.
* ``seed`` sets a random seed for the shuffle buffer.

The full and always up-to-date list of supported options can be found in the ``LhotseDataLoadingConfig`` class.
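
The relationship between ``num_buckets`` and ``bucket_duration_bins`` can be sketched as follows. This is an illustrative assumption, not the algorithm behind ``lhotse cut estimate-bucket-bins``: the hypothetical ``estimate_duration_bins`` below simply places boundaries at equal-count quantiles of the observed durations, and ``bucket_index`` looks up the bucket of an utterance by bisection.

```python
from bisect import bisect_right


def estimate_duration_bins(durations, num_buckets):
    """Return num_buckets - 1 boundaries that split the durations into equal-count groups."""
    ordered = sorted(durations)
    return [ordered[len(ordered) * i // num_buckets] for i in range(1, num_buckets)]


def bucket_index(duration, bins):
    """Index of the bucket for an utterance: bucket 0 is the shortest, len(bins) the longest."""
    return bisect_right(bins, duration)


durations = [float(d) for d in range(1, 13)]  # toy data: 1 s, 2 s, ..., 12 s
num_buckets = 4

bins = estimate_duration_bins(durations, num_buckets)
print(bins)                      # boundaries in seconds
print(bucket_index(5.0, bins))   # bucket of a 5-second utterance
```

Note that ``num_buckets`` buckets need only ``num_buckets - 1`` boundaries, which is why the configuration option must have exactly that length.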

Preparing Text-Only Data for Hybrid ASR-TTS Models
--------------------------------------------------

84 changes: 84 additions & 0 deletions nemo/collections/asr/data/audio_to_text_lhotse.py
@@ -0,0 +1,84 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional, Tuple

import torch.utils.data
from lhotse.dataset import AudioSamples
from lhotse.dataset.collation import collate_vectors

from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType


class LhotseSpeechToTextBpeDataset(torch.utils.data.Dataset):
    """
    This dataset is based on BPE datasets from audio_to_text.py.
    Unlike native NeMo datasets, Lhotse dataset defines only the mapping from
    a CutSet (meta-data) to a mini-batch with PyTorch tensors.
    Specifically, it performs tokenization, I/O, augmentation, and feature extraction (if any).
    Managing data, sampling, de-duplication across workers/nodes etc. is all handled
    by Lhotse samplers instead.
    """

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        return {
            'audio_signal': NeuralType(('B', 'T'), AudioSignal()),
            'a_sig_length': NeuralType(tuple('B'), LengthsType()),
            'transcripts': NeuralType(('B', 'T'), LabelsType()),
            'transcript_length': NeuralType(tuple('B'), LengthsType()),
            'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
        }

    def __init__(self, tokenizer):
        super().__init__()
        self.tokenizer = TokenizerWrapper(tokenizer)
        self.load_audio = AudioSamples(fault_tolerant=True)

    def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]:
        audio, audio_lens, cuts = self.load_audio(cuts)
        tokens = [torch.as_tensor(self.tokenizer(c.supervisions[0].text, c.supervisions[0].language)) for c in cuts]
        token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long)
        tokens = collate_vectors(tokens, padding_value=0)
        return audio, audio_lens, tokens, token_lens


class TokenizerWrapper:
    """
    Provide a unified interface for NeMo Tokenizer, AggregateTokenizer, and (char) Parser.
    """

    def __init__(self, tokenizer):
        self._tokenizer = tokenizer
        if isinstance(tokenizer, AggregateTokenizer):
            self._impl = self._call_agg_tokenizer
        elif isinstance(tokenizer, TokenizerSpec):
            self._impl = self._call_tokenizer
        else:
            self._impl = self._call_parser

    def __call__(self, text: str, lang: str | None = None):
        return self._impl(text, lang)

    def _call_agg_tokenizer(self, text: str, lang: str | None = None):
        assert lang is not None, "Expected 'lang' to be set for AggregateTokenizer."
        return self._tokenizer.text_to_ids(text, lang)

    def _call_tokenizer(self, text: str, lang: str | None = None):
        return self._tokenizer.text_to_ids(text)

    def _call_parser(self, text: str, lang: str | None = None):
        return self._tokenizer(text)
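
The dispatch pattern used by ``TokenizerWrapper`` can be tried out without NeMo installed. The sketch below mirrors the idea with hypothetical stand-in classes (``CharParser``, ``SimpleTokenizer``, ``PerLanguageTokenizer``, ``UnifiedTokenizer``); none of them are NeMo APIs, and their toy tokenization rules exist only to make the example runnable.

```python
from typing import List, Optional


class CharParser:
    """Stand-in for a character parser: a plain callable on text."""

    def __call__(self, text: str) -> List[int]:
        return [ord(ch) for ch in text]


class SimpleTokenizer:
    """Stand-in for a single-language tokenizer exposing text_to_ids(text)."""

    def text_to_ids(self, text: str) -> List[int]:
        return [len(word) for word in text.split()]


class PerLanguageTokenizer:
    """Stand-in for an aggregate tokenizer exposing text_to_ids(text, lang)."""

    def __init__(self, tokenizers):
        self.tokenizers = tokenizers

    def text_to_ids(self, text: str, lang: str) -> List[int]:
        return self.tokenizers[lang].text_to_ids(text)


class UnifiedTokenizer:
    """Pick the calling convention once, at construction time, then expose a single __call__."""

    def __init__(self, tokenizer):
        self._tokenizer = tokenizer
        if isinstance(tokenizer, PerLanguageTokenizer):
            self._impl = self._call_per_language
        elif isinstance(tokenizer, SimpleTokenizer):
            self._impl = self._call_tokenizer
        else:
            self._impl = self._call_parser

    def __call__(self, text: str, lang: Optional[str] = None) -> List[int]:
        return self._impl(text, lang)

    def _call_per_language(self, text, lang):
        if lang is None:
            raise ValueError("Expected 'lang' to be set for a per-language tokenizer.")
        return self._tokenizer.text_to_ids(text, lang)

    def _call_tokenizer(self, text, lang):
        return self._tokenizer.text_to_ids(text)  # 'lang' is ignored here

    def _call_parser(self, text, lang):
        return self._tokenizer(text)  # 'lang' is ignored here


char_ids = UnifiedTokenizer(CharParser())("ab")
word_ids = UnifiedTokenizer(SimpleTokenizer())("hello world", "en")
multi_ids = UnifiedTokenizer(PerLanguageTokenizer({"en": SimpleTokenizer()}))("hello world", "en")
print(char_ids, word_ids, multi_ids)
```

Resolving the branch once in ``__init__`` keeps the per-utterance call path free of ``isinstance`` checks, so the dataset code can call the wrapper with ``(text, lang)`` regardless of which tokenizer type it was given.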
10 changes: 10 additions & 0 deletions nemo/collections/asr/models/ctc_bpe_models.py
@@ -21,11 +21,13 @@

from nemo.collections.asr.data import audio_to_text_dataset
from nemo.collections.asr.data.audio_to_text_dali import AudioToBPEDALIDataset
from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset
from nemo.collections.asr.losses.ctc import CTCLoss
from nemo.collections.asr.metrics.wer import WER
from nemo.collections.asr.models.ctc_models import EncDecCTCModel
from nemo.collections.asr.parts.mixins import ASRBPEMixin
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCBPEDecoding, CTCBPEDecodingConfig
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.core.classes.common import PretrainedModelInfo
from nemo.utils import logging, model_utils

@@ -90,6 +92,14 @@ def __init__(self, cfg: DictConfig, trainer=None):
        )

    def _setup_dataloader_from_config(self, config: Optional[Dict]):
        if config.get("use_lhotse"):
            return get_lhotse_dataloader_from_config(
                config,
                global_rank=self.global_rank,
                world_size=self.world_size,
                dataset=LhotseSpeechToTextBpeDataset(tokenizer=self.tokenizer),
            )

        dataset = audio_to_text_dataset.get_audio_to_text_bpe_dataset_from_config(
            config=config,
            local_rank=self.local_rank,
20 changes: 20 additions & 0 deletions nemo/collections/asr/models/ctc_models.py
@@ -25,12 +25,15 @@

from nemo.collections.asr.data import audio_to_text_dataset
from nemo.collections.asr.data.audio_to_text_dali import AudioToCharDALIDataset, DALIOutputs
from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset
from nemo.collections.asr.losses.ctc import CTCLoss
from nemo.collections.asr.metrics.wer import WER
from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel
from nemo.collections.asr.parts.mixins import ASRModuleMixin, InterCTCMixin
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig
from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.collections.common.parts.preprocessing.parsers import make_parser
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.classes.mixins import AccessMixin
from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, LogprobsType, NeuralType, SpectrogramType
@@ -350,6 +353,23 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
        # Automatically inject args from model config to dataloader config
        audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate')
        audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='labels')

        if config.get("use_lhotse"):
            return get_lhotse_dataloader_from_config(
                config,
                global_rank=self.global_rank,
                world_size=self.world_size,
                dataset=LhotseSpeechToTextBpeDataset(
                    tokenizer=make_parser(
                        labels=config.get('labels', None),
                        name=config.get('parser', 'en'),
                        unk_id=config.get('unk_index', -1),
                        blank_id=config.get('blank_index', -1),
                        do_normalize=config.get('normalize_transcripts', False),
                    ),
                ),
            )

        dataset = audio_to_text_dataset.get_audio_to_text_char_dataset_from_config(
            config=config,
            local_rank=self.local_rank,
11 changes: 11 additions & 0 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
@@ -22,13 +22,15 @@

from nemo.collections.asr.data import audio_to_text_dataset
from nemo.collections.asr.data.audio_to_text_dali import AudioToBPEDALIDataset
from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset
from nemo.collections.asr.losses.ctc import CTCLoss
from nemo.collections.asr.losses.rnnt import RNNTLoss
from nemo.collections.asr.metrics.wer import WER
from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel
from nemo.collections.asr.parts.mixins import ASRBPEMixin
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCBPEDecoding, CTCBPEDecodingConfig
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTBPEDecoding, RNNTBPEDecodingConfig
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.core.classes.common import PretrainedModelInfo
from nemo.utils import logging, model_utils

@@ -128,6 +130,15 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        self.cur_decoder = "rnnt"

    def _setup_dataloader_from_config(self, config: Optional[Dict]):

        if config.get("use_lhotse"):
            return get_lhotse_dataloader_from_config(
                config,
                global_rank=self.global_rank,
                world_size=self.world_size,
                dataset=LhotseSpeechToTextBpeDataset(tokenizer=self.tokenizer,),
            )

        dataset = audio_to_text_dataset.get_audio_to_text_bpe_dataset_from_config(
            config=config,
            local_rank=self.local_rank,
10 changes: 10 additions & 0 deletions nemo/collections/asr/models/rnnt_bpe_models.py
@@ -22,11 +22,13 @@

from nemo.collections.asr.data import audio_to_text_dataset
from nemo.collections.asr.data.audio_to_text_dali import AudioToBPEDALIDataset
from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset
from nemo.collections.asr.losses.rnnt import RNNTLoss
from nemo.collections.asr.metrics.wer import WER
from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel
from nemo.collections.asr.parts.mixins import ASRBPEMixin
from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTBPEDecoding, RNNTBPEDecodingConfig
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.core.classes.common import PretrainedModelInfo
from nemo.utils import logging, model_utils

@@ -485,6 +487,14 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig):
        logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}")

    def _setup_dataloader_from_config(self, config: Optional[Dict]):
        if config.get("use_lhotse"):
            return get_lhotse_dataloader_from_config(
                config,
                global_rank=self.global_rank,
                world_size=self.world_size,
                dataset=LhotseSpeechToTextBpeDataset(tokenizer=self.tokenizer,),
            )

        dataset = audio_to_text_dataset.get_audio_to_text_bpe_dataset_from_config(
            config=config,
            local_rank=self.local_rank,