ASR-TTS Models: Support hybrid RNNT-CTC, improve docs. #6620

Merged
11 changes: 6 additions & 5 deletions docs/source/asr/configs.rst
@@ -885,17 +885,17 @@ Hybrid ASR-TTS Model Configuration

:ref:`Hybrid ASR-TTS model <Hybrid-ASR-TTS_model>` consists of three parts:

* ASR model (``EncDecCTCModelBPE`` or ``EncDecRNNTBPEModel``)
* ASR model (``EncDecCTCModelBPE``, ``EncDecRNNTBPEModel`` or ``EncDecHybridRNNTCTCBPEModel``)
* TTS Mel Spectrogram Generator (currently, only :ref:`FastPitch <FastPitch_model>` model is supported)
* Enhancer model (optional)
* :ref:`Enhancer model <SpectrogramEnhancer_model>` (optional)

The config also allows specifying a :ref:`text-only dataset <Hybrid-ASR-TTS_model__Text-Only-Data>`.

Main parts of the config:

* ASR model
* ``asr_model_path``: path to the ASR model checkpoint (``.nemo``) file; it is loaded only once, after which the config of the ASR model is stored in the ``asr_model`` field
* ``asr_model_type``: needed only when training from scratch, ``rnnt_bpe`` corresponds to ``EncDecRNNTBPEModel``, ``ctc_bpe`` to ``EncDecCTCModelBPE``
* ``asr_model_type``: needed only when training from scratch. ``rnnt_bpe`` corresponds to ``EncDecRNNTBPEModel``, ``ctc_bpe`` to ``EncDecCTCModelBPE``, ``hybrid_rnnt_ctc_bpe`` to ``EncDecHybridRNNTCTCBPEModel``
* ``asr_model_fuse_bn``: whether to fuse BatchNorm in the pretrained ASR model; this can improve quality in a fine-tuning scenario
* TTS model
* ``tts_model_path``: path to the pretrained TTS model checkpoint (``.nemo``) file; it is loaded only once, after which the config of the model is stored in the ``tts_model`` field
@@ -907,7 +907,7 @@ Main parts of the config:
* ``speakers_filepath``: path (or paths) to the text file containing speaker ids for the multi-speaker TTS model (speakers are sampled randomly during training)
* ``min_words`` and ``max_words``: parameters to filter text-only manifests by the number of words
* ``tokenizer_workers``: number of workers for initial tokenization (when loading the data). ``num_CPUs / num_GPUs`` is a recommended value.
* ``asr_tts_sampling_technique``, ``asr_tts_sampling_temperature``, ``asr_tts_sampling_probabilities``: sampling parameters for text-only and audio-text data (if both specified). See parameters for ``nemo.collections.common.data.ConcatDataset``
* ``asr_tts_sampling_technique``, ``asr_tts_sampling_temperature``, ``asr_tts_sampling_probabilities``: sampling parameters for text-only and audio-text data (if both are specified). They correspond to the ``sampling_technique``, ``sampling_temperature``, and ``sampling_probabilities`` parameters of :class:`ConcatDataset <nemo.collections.common.data.dataset.ConcatDataset>`.
* all other components are similar to conventional ASR models
* ``validation_ds`` and ``test_ds`` correspond to the underlying ASR model

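To make the role of the sampling temperature more concrete, here is a minimal sketch of temperature-based mixing between an audio-text dataset and a text-only dataset. It assumes the common convention that each dataset is drawn with probability proportional to its size raised to ``1 / temperature``; the helper name ``temperature_probabilities`` and the dataset sizes are hypothetical, and the actual behavior should be checked against the ``ConcatDataset`` documentation.

```python
from typing import List


def temperature_probabilities(sizes: List[int], temperature: float) -> List[float]:
    """Turn dataset sizes into sampling probabilities (illustrative helper, not a NeMo API).

    Assumed convention: weight_i = size_i ** (1 / temperature), then normalize.
    temperature = 1 keeps size-proportional sampling; larger values flatten it.
    """
    weights = [size ** (1.0 / temperature) for size in sizes]
    total = sum(weights)
    return [weight / total for weight in weights]


# Hypothetical sizes: 1000 audio-text pairs and 9000 text-only utterances.
dataset_sizes = [1000, 9000]

proportional = temperature_probabilities(dataset_sizes, temperature=1.0)
flattened = temperature_probabilities(dataset_sizes, temperature=5.0)
```

Under this assumption a higher temperature moves the mix toward uniform sampling, so the smaller audio-text dataset is seen more often than its size alone would suggest.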
@@ -920,7 +920,7 @@ Main parts of the config:
# asr model
asr_model_path: ???
asr_model: null
asr_model_type: null # rnnt_bpe or ctc_bpe, needed only if instantiating from config, otherwise type is auto inferred
asr_model_type: null # rnnt_bpe, ctc_bpe or hybrid_rnnt_ctc_bpe; needed only if instantiating from config, otherwise type is auto inferred
asr_model_fuse_bn: false # only ConformerEncoder supported now, use false for other models

# tts model
@@ -972,6 +972,7 @@ Training from Scratch
To train an ASR model from scratch using text-only data, use the ``<NeMo_git_root>/examples/asr/asr_with_tts/speech_to_text_bpe_with_text.py`` script with a conventional ASR model config, e.g. ``<NeMo_git_root>/examples/asr/conf/conformer/conformer_ctc_bpe.yaml`` or ``<NeMo_git_root>/examples/asr/conf/conformer/conformer_transducer_bpe.yaml``.

Please specify the ASR model type, paths to the TTS model, and (optional) enhancer, along with text-only data-related fields.
Use ``++`` or ``+`` markers for these options, since the options are not present in the original ASR model config.

.. code-block:: shell

Expand Down
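The ``++`` and ``+`` markers mentioned in the "Training from Scratch" hunk are Hydra override prefixes: a plain ``key=value`` changes an existing key, ``+key=value`` adds a key that must not exist yet, and ``++key=value`` adds the key or overrides it if present. The sketch below is a toy model of those rules on a flat dictionary, only to illustrate why the prefixes are needed for fields missing from a conventional ASR config; ``apply_override`` and the sample values are hypothetical, and real Hydra also handles nested keys, typing, and more.

```python
def apply_override(config: dict, arg: str) -> dict:
    """Apply one flat `key=value` override to a copy of `config` (toy model of Hydra's rules)."""
    key, value = arg.split("=", 1)
    updated = dict(config)
    if key.startswith("++"):
        # Add the key, or override it if it already exists.
        updated[key[2:]] = value
    elif key.startswith("+"):
        # Add a new key; it must not exist yet.
        name = key[1:]
        if name in updated:
            raise KeyError(f"'{name}' already exists; use '++{name}' to override it")
        updated[name] = value
    else:
        # Plain override; the key must already be in the config.
        if key not in updated:
            raise KeyError(f"'{key}' is not in the config; use '+{key}' or '++{key}'")
        updated[key] = value
    return updated


# A conventional ASR config has no TTS-related fields (hypothetical flat config).
asr_config = {"name": "conformer_ctc_bpe"}

with_type = apply_override(asr_config, "++asr_model_type=ctc_bpe")
with_tts = apply_override(with_type, "+tts_model_path=/path/to/tts.nemo")

try:
    apply_override(asr_config, "tts_model_path=/path/to/tts.nemo")
    plain_override_failed = False
except KeyError:
    plain_override_failed = True
```

Using ``++`` throughout is the more forgiving choice, since it works whether or not the base config already defines the field.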
2 changes: 1 addition & 1 deletion docs/source/asr/models.rst
@@ -330,7 +330,7 @@ The model consists of three models:

* ASR model (``EncDecCTCModelBPE`` or ``EncDecRNNTBPEModel``)
* Frozen TTS Mel Spectrogram Generator (currently, only :ref:`FastPitch <FastPitch_model>` model is supported)
* Optional frozen Enhancer model trained to mitigate mismatch between real and generated mel spectrogram
* Optional frozen :ref:`Spectrogram Enhancer model <SpectrogramEnhancer_model>` trained to mitigate the mismatch between real and generated mel spectrograms

.. image:: images/hybrid_asr_tts_model.png
:align: center
13 changes: 13 additions & 0 deletions docs/source/common/data.rst
@@ -0,0 +1,13 @@
Data
----

.. autoclass:: nemo.collections.common.data.dataset.ConcatDataset
:show-inheritance:
:members:
:undoc-members:


.. autoclass:: nemo.collections.common.data.dataset.ConcatMapDataset
:show-inheritance:
:members:
:undoc-members:
1 change: 1 addition & 0 deletions docs/source/common/intro.rst
@@ -10,3 +10,4 @@ The common collection contains things that could be used across all collections.
losses
metrics
tokenizers
data
5 changes: 5 additions & 0 deletions docs/source/tts/api.rst
@@ -25,6 +25,11 @@ Mel-Spectrogram Generators
:members:
:exclude-members: setup_training_data, setup_validation_data, training_step, validation_epoch_end, validation_step, setup_test_data, on_train_epoch_start

.. autoclass:: nemo.collections.tts.models.SpectrogramEnhancerModel
:show-inheritance:
:members:
:exclude-members: setup_training_data, setup_validation_data, training_step, validation_epoch_end, validation_step, setup_test_data, on_train_epoch_start


Speech-to-Text Aligner Models
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
13 changes: 12 additions & 1 deletion docs/source/tts/models.rst
@@ -112,7 +112,7 @@ Speech-to-text alignment is a critical component of neural TTS models. Autoregre


End2End Models
--------
--------------

VITS
~~~~~~~~~~~~~~~
@@ -123,6 +123,17 @@ VITS is an end-to-end speech synthesis model, which generates raw waveform audio
:alt: vits model
:scale: 25%


Enhancers
---------

.. _SpectrogramEnhancer_model:

Spectrogram Enhancer
~~~~~~~~~~~~~~~~~~~~
A GAN-based model that adds detail to blurry spectrograms produced by TTS models such as Tacotron or FastPitch.


References
----------

2 changes: 1 addition & 1 deletion examples/asr/asr_with_tts/speech_to_text_bpe_with_text.py
@@ -19,7 +19,7 @@
```shell
python speech_to_text_bpe_with_text.py \
# (Optional: --config-path=<path to dir of configs> --config-name=<name of config without .yaml>) \
++asr_model_type=<rnnt_bpe or ctc_bpe> \
++asr_model_type=<rnnt_bpe, ctc_bpe or hybrid_rnnt_ctc_bpe> \
++tts_model_path=<path to compatible tts model> \
++enhancer_model_path=<optional path to enhancer model> \
model.tokenizer.dir=<path to tokenizer> \
2 changes: 1 addition & 1 deletion examples/asr/conf/asr_tts/hybrid_asr_tts.yaml
@@ -8,7 +8,7 @@ model:
# asr model
asr_model_path: ???
asr_model: null
asr_model_type: null # rnnt_bpe or ctc_bpe, needed only if instantiating from config, otherwise type is auto inferred
asr_model_type: null # rnnt_bpe, ctc_bpe or hybrid_rnnt_ctc_bpe; needed only if instantiating from config, otherwise type is auto inferred
asr_model_fuse_bn: false # only ConformerEncoder supported now, use false for other models

# tts model
6 changes: 2 additions & 4 deletions nemo/collections/asr/data/text_to_text.py
@@ -33,12 +33,10 @@
from nemo.collections.common.tokenizers import TokenizerSpec
from nemo.core.classes import Dataset, IterableDataset
from nemo.utils import logging
from nemo.utils.import_guards import optional_import_guard

try:
with optional_import_guard():
from nemo_text_processing.text_normalization.normalize import Normalizer
except Exception as e:
logging.warning(e)
logging.warning("nemo_text_processing is not installed")

AnyPath = Union[Path, str]

12 changes: 9 additions & 3 deletions nemo/collections/asr/models/hybrid_asr_tts_models.py
@@ -33,6 +33,7 @@
)
from nemo.collections.asr.models.asr_model import ASRModel
from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE
from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel
from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel
from nemo.collections.asr.modules.conformer_encoder import ConformerEncoder
from nemo.collections.asr.parts.preprocessing.features import clean_spectrogram_batch, normalize_batch
@@ -89,7 +90,7 @@ class ASRWithTTSModel(ASRModel):
Text-only data can be mixed with audio-text pairs
"""

asr_model: Union[EncDecRNNTBPEModel, EncDecCTCModelBPE]
asr_model: Union[EncDecRNNTBPEModel, EncDecCTCModelBPE, EncDecHybridRNNTCTCBPEModel]
tts_model: FastPitchModel
enhancer_model: Optional[SpectrogramEnhancerModel]

@@ -100,20 +101,25 @@ class ASRModelTypes(PrettyStrEnum):

RNNT_BPE = "rnnt_bpe"
CTC_BPE = "ctc_bpe"
HYBRID_RNNT_CTC_BPE = "hybrid_rnnt_ctc_bpe"

@classmethod
def from_asr_model(cls, model: Any):
if isinstance(model, EncDecRNNTBPEModel):
return cls.RNNT_BPE
if isinstance(model, EncDecCTCModelBPE):
return cls.CTC_BPE
if isinstance(model, EncDecHybridRNNTCTCBPEModel):
return cls.HYBRID_RNNT_CTC_BPE
raise ValueError(f"Unsupported model type: {type(model)}")

def get_asr_cls(self):
if self == self.RNNT_BPE:
return EncDecRNNTBPEModel
if self == self.CTC_BPE:
return EncDecCTCModelBPE
if self == self.HYBRID_RNNT_CTC_BPE:
return EncDecHybridRNNTCTCBPEModel
raise NotImplementedError(f"Not implemented for value {self.value}")

@classmethod
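The hunk above extends the two-way mapping between the ``asr_model_type`` string and the ASR model class. A minimal, self-contained sketch of that round trip follows. It uses a plain ``Enum`` and empty stand-in classes instead of NeMo's ``PrettyStrEnum`` and the real model classes (all stand-in names are hypothetical), and it replaces the chain of ``if`` statements with a lookup table purely for brevity.

```python
from enum import Enum


# Empty stand-ins for the real NeMo model classes (hypothetical, for illustration only).
class RNNTBPEStandIn:
    pass


class CTCBPEStandIn:
    pass


class HybridRNNTCTCBPEStandIn:
    pass


class ASRModelTypes(Enum):
    RNNT_BPE = "rnnt_bpe"
    CTC_BPE = "ctc_bpe"
    HYBRID_RNNT_CTC_BPE = "hybrid_rnnt_ctc_bpe"

    @classmethod
    def from_asr_model(cls, model):
        """Infer the type from a model instance (used when loading a checkpoint)."""
        for model_type, model_cls in _TYPE_TO_CLS.items():
            if isinstance(model, model_cls):
                return model_type
        raise ValueError(f"Unsupported model type: {type(model)}")

    def get_asr_cls(self):
        """Return the class to instantiate (used when building from a config)."""
        return _TYPE_TO_CLS[self]


# Defined after the enum so that it is not turned into an enum member.
_TYPE_TO_CLS = {
    ASRModelTypes.RNNT_BPE: RNNTBPEStandIn,
    ASRModelTypes.CTC_BPE: CTCBPEStandIn,
    ASRModelTypes.HYBRID_RNNT_CTC_BPE: HybridRNNTCTCBPEStandIn,
}

# The string from the config selects the class; an instance maps back to the same type.
hybrid_type = ASRModelTypes("hybrid_rnnt_ctc_bpe")
hybrid_cls = hybrid_type.get_asr_cls()
inferred_type = ASRModelTypes.from_asr_model(hybrid_cls())
```

With real classes that inherit from one another, the order of the ``isinstance`` checks would matter; the stand-ins here are unrelated, so the sketch does not exercise that case.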
@@ -540,7 +546,7 @@ def _setup_text_dataset_from_config(
manifest_filepath=text_data_config.manifest_filepath,
speakers_filepath=text_data_config.speakers_filepath,
asr_tokenizer=self.asr_model.tokenizer,
asr_use_start_end_token=train_data_config.use_start_end_token,
asr_use_start_end_token=train_data_config.get("use_start_end_token", False),
tts_parser=self.tts_model.parser,
tts_text_pad_id=self.tts_model.vocab.pad,
tts_text_normalizer=self.tts_model.normalizer,
@@ -556,7 +562,7 @@ def _setup_text_dataset_from_config(
manifest_filepath=text_data_config.manifest_filepath,
speakers_filepath=text_data_config.speakers_filepath,
asr_tokenizer=self.asr_model.tokenizer,
asr_use_start_end_token=train_data_config.use_start_end_token,
asr_use_start_end_token=train_data_config.get("use_start_end_token", False),
tts_parser=self.tts_model.parser,
tts_text_pad_id=self.tts_model.vocab.pad,
tts_text_normalizer=self.tts_model.normalizer,
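The switch from ``train_data_config.use_start_end_token`` to ``train_data_config.get("use_start_end_token", False)`` makes the key optional, presumably because not every ASR training data config defines it. The sketch below shows the difference with a plain ``dict`` standing in for the real config object; the sample keys and values are hypothetical.

```python
# Plain dict standing in for the training data config (the real object is a NeMo/Hydra config).
train_data_config = {"manifest_filepath": "train_manifest.json", "batch_size": 16}

# Strict access fails when the key is absent.
try:
    strict_value = train_data_config["use_start_end_token"]
    strict_access_failed = False
except KeyError:
    strict_access_failed = True

# Access with a default falls back to False instead of failing.
asr_use_start_end_token = train_data_config.get("use_start_end_token", False)

# When the key is present, the configured value still wins over the default.
explicit_config = dict(train_data_config, use_start_end_token=True)
explicit_value = explicit_config.get("use_start_end_token", False)
```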
38 changes: 38 additions & 0 deletions nemo/utils/import_guards.py
@@ -0,0 +1,38 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager

from nemo.utils import logging


@contextmanager
Collaborator:

How would you check whether the import was successful? We use a bool flag like HAS_APEX or NUMBA to determine whether the module could be imported.

Collaborator Author:

Here I just wanted to suppress the missing import. I use it only for annotation purposes ("Normalizer" is not a valid annotation if Normalizer is not imported!), and I don't care if it is missing (the TTS model will fail with the appropriate message).
This doesn't sacrifice usability and still works in all cases.

If I need some library to be used explicitly, I was thinking of something like @k2_required decorators, which would allow importing everything everywhere but raise an error if the method is called or the class is instantiated.

Collaborator Author:

import_guard is removed

def optional_import_guard(warn_on_error=False):
    """
    Context manager to wrap optional import.
    Suppresses `ImportError` (also, `ModuleNotFoundError`), adds warning if `warn_on_error` is True.
    Use separately for each library.

    >>> with optional_import_guard():
    ...     import optional_library

    :param warn_on_error: log warning if import resulted in error
    """
    try:
        yield
    except ImportError as e:
Collaborator:

Except (ImportError, ModuleNotFoundError):

Collaborator Author (@artbataev, May 12, 2023):

Why? ModuleNotFoundError is a subclass of ImportError, so it will be caught:

>>> issubclass(ModuleNotFoundError, ImportError)
True

>>> try:
...     raise ModuleNotFoundError
... except ImportError:
...     print("Caught")
Caught

Collaborator Author (@artbataev, May 12, 2023):

import_guard is removed

        if warn_on_error:
            logging.warning(e)
    finally:
        pass
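The review thread contrasts two ways of handling an optional dependency: silently suppressing the failed import, as this context manager does, and recording a boolean flag in the style of ``HAS_APEX``. The sketch below shows both side by side. It uses the standard-library ``logging`` module rather than ``nemo.utils.logging``, and ``nonexistent_optional_library_xyz`` and ``HAS_OPTIONAL_LIBRARY`` are hypothetical names chosen so that the import fails. Note that, according to the thread, the guard was ultimately removed from the PR.

```python
import logging
from contextlib import contextmanager

logger = logging.getLogger(__name__)


@contextmanager
def optional_import_guard(warn_on_error=False):
    """Suppress ImportError (and its subclass ModuleNotFoundError) raised inside the block."""
    try:
        yield
    except ImportError as e:
        if warn_on_error:
            logger.warning(e)


# Approach 1: suppress the failure; nothing records whether the import worked.
with optional_import_guard():
    import nonexistent_optional_library_xyz  # hypothetical module, expected to be missing

name_is_bound = "nonexistent_optional_library_xyz" in globals()

# Approach 2: keep an explicit flag that later code can check before using the library.
try:
    import nonexistent_optional_library_xyz  # hypothetical module, expected to be missing

    HAS_OPTIONAL_LIBRARY = True
except ImportError:
    HAS_OPTIONAL_LIBRARY = False
```

The guard keeps the import site short, but the flag makes it possible to raise a clear error at the point where the library is actually needed.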