From 8d1c01846a8899fca57e8392b8f8eb1126fa50e5 Mon Sep 17 00:00:00 2001
From: Vladimir Bataev
Date: Fri, 30 Jun 2023 13:04:46 -0700
Subject: [PATCH 1/6] Add ASR with TTS Tutorial

Signed-off-by: Vladimir Bataev
---
 tutorials/asr/ASR_TTS_Tutorial.ipynb | 797 +++++++++++++++++++++++++++
 1 file changed, 797 insertions(+)
 create mode 100644 tutorials/asr/ASR_TTS_Tutorial.ipynb

diff --git a/tutorials/asr/ASR_TTS_Tutorial.ipynb b/tutorials/asr/ASR_TTS_Tutorial.ipynb
new file mode 100644
index 000000000000..567c7df19088
--- /dev/null
+++ b/tutorials/asr/ASR_TTS_Tutorial.ipynb
@@ -0,0 +1,797 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "a3570803-9bfa-4e97-9891-5ae0759eb8ca",
+   "metadata": {},
+   "source": [
+    "# Hybrid ASR-TTS models tutorial"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "50fc294f-f319-4465-8f90-a28b49843e60",
+   "metadata": {},
+   "source": [
+    "This tutorial is designed to introduce using Hybrid ASR-TTS models, aka `ASRWithTTSModel`, to finetune existing ASR models using an integrated text-to-mel-spectrogram generator. "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d2a01ca5-bd48-4d82-a97d-5b07a7b27ca0",
+   "metadata": {},
+   "source": [
+    "## ASR-TTS Models: Description"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b32467a9-c458-4590-aff7-e8d1e91b0870",
+   "metadata": {},
+   "source": [
+    "### Problem\n",
+    "\n",
+    "Adapting ASR models to a new text domain is a challenging task. Modern end-to-end systems can require several hundred or even thousands of hours of audio to perform recognition with high accuracy. Acquiring audio-text paired data for a specific domain can be prohibitively expensive. Text-only data, on the other hand, is widely available. \n",
+    "\n",
+    "One of the approaches for efficient adaptation is synthesizing audio data from text and using such data for training the ASR model conventionally. We modify this approach, incorporating TTS and ASR systems into a single model. We use only a lightweight multi-speaker text-to-mel-spectrogram generator (without vocoder) with an optional enhancer that mitigates the mismatch between natural and synthetic spectrograms.\n",
+    "\n",
+    "### Architecture\n",
+    "\n",
+    "*Figure: ASR-TTS model architecture*\n",
+    "\n",
+    "`ASRWithTTSModel` is a transparent wrapper for three models:\n",
+    "- ASR model (`EncDecCTCModelBPE`, `EncDecRNNTBPEModel` or `EncDecHybridRNNTCTCBPEModel` are supported)\n",
+    "- frozen text-to-mel-spectrogram model (currently, only `FastPitch` model is supported)\n",
+    "- optional frozen enhancer model\n",
+    "\n",
+    "The architecture is shown in the figure. \n",
+    "\n",
+    "The model can take text or audio as input during training. In the case of audio input, a mel spectrogram is extracted as usual and passed to ASR neural network. In the case of textual input, the mel spectrogram generator produces a spectrogram on the fly from the text. The spectrogram is improved by the enhancer (if present) and fed into the ASR model. \n",
+    "\n",
+    "### Capabilities and limitations\n",
+    "\n",
+    "This approach can be used to finetune the pretrained ASR model using text-only data. Training new models from scratch is also possible. The text should contain phrases and sentences and be split into sentences (~45 words maximum, corresponding to ~16.7 seconds of synthesized audio). Using only separate words is not recommended, since this does not allow the ASR model to adapt to recognizing new words in context. \n",
+    "\n",
+    "Mixing audio-text pairs with text-only data from the original domain is recommended to preserve performance on the original data. \n",
\n", + "Also, fusing BatchNorm (see parameters below) is recommended for the best performance when using a large proportion of text compared to the amount of audio-text pairs in finetuning process.\n", + "\n", + "\n", + "### Implementation Details and Experiments\n", + "\n", + "Further details about implementation and experiments can be found in the paper [Text-only domain adaptation for end-to-end ASR using integrated text-to-mel-spectrogram generator](https://arxiv.org/abs/2302.14036)\n" + ] + }, + { + "cell_type": "markdown", + "id": "2702d081-c675-4a96-8263-6059e310d048", + "metadata": {}, + "source": [ + "## Example: Finetuning ASR model using text-only data" + ] + }, + { + "cell_type": "markdown", + "id": "30fe41a3-f36c-4803-a7f0-4260fb111478", + "metadata": {}, + "source": [ + "In this example, we will finetune a pretrained small Conformer-CTC model using text-only data from the AN4 dataset. [AN4 dataset](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/datasets.html#an4-dataset) is a small dataset that consists of sentences of people spelling out addresses, names, and other entities.\n", + "\n", + "The model is pretrained on LibriSpeech data and performs poorly on AN4 data (`~17.7%` WER on test data).\n", + "We will use only text from the train part to construct text-only training data for our model and will achieve a good performance on the test part of the AN4 dataset (`~2%` WER)." + ] + }, + { + "cell_type": "markdown", + "id": "923819bb-7822-412a-8f9b-98c76c70e0bb", + "metadata": {}, + "source": [ + "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n", + "\n", + "Instructions for setting up Colab are as follows:\n", + "1. Open a new Python 3 notebook.\n", + "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n", + "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n", + "4. Run the following cell to set up dependencies.\n", + "\n", + "NOTE: User is responsible for checking the content of datasets and the applicable licenses and determining if suitable for the intended use." 
+ ] + }, + { + "cell_type": "markdown", + "id": "4685a9da-b3f8-4b95-ba74-64a114223233", + "metadata": {}, + "source": [ + "### Install Dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4d22d241-6c46-492c-99db-3bd69777243c", + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " import google.colab\n", + "\n", + " IN_COLAB = True\n", + "except (ImportError, ModuleNotFoundError):\n", + " IN_COLAB = False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc38a961-8822-4685-89ae-ab6f591f9c28", + "metadata": {}, + "outputs": [], + "source": [ + "BRANCH = 'main'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd60b1c4-7b1d-421d-9d63-95d7458bbcbd", + "metadata": {}, + "outputs": [], + "source": [ + "# If you're using Google Colab and not running locally, run this cell.\n", + "\n", + "if IN_COLAB:\n", + " ## Install dependencies\n", + " !pip install wget\n", + " !apt-get install sox libsndfile1 ffmpeg\n", + " !pip install text-unidecode\n", + "\n", + " ## Install NeMo\n", + " !python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]" + ] + }, + { + "cell_type": "markdown", + "id": "08f99618-6f83-44b3-bc8e-f7df04fc471c", + "metadata": {}, + "source": [ + "### Import necessary libraries and utils" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "74f780b1-9b72-4acf-bcf0-64e1ce84e76d", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from pathlib import Path\n", + "import string\n", + "import tempfile\n", + "\n", + "from omegaconf import OmegaConf\n", + "import pytorch_lightning as pl\n", + "import torch\n", + "from tqdm.auto import tqdm\n", + "import wget\n", + "\n", + "from nemo.collections.asr.models import EncDecCTCModelBPE\n", + "from nemo.collections.asr.models.hybrid_asr_tts_models import ASRWithTTSModel\n", + "from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest\n", + "from nemo.collections.tts.models import FastPitchModel, SpectrogramEnhancerModel\n", + "from nemo.utils.notebook_utils import download_an4\n", + "\n", + "from nemo_text_processing.text_normalization.normalize import Normalizer" + ] + }, + { + "cell_type": "markdown", + "id": "ca928d36-fb0d-439b-bac0-299e98a72d02", + "metadata": {}, + "source": [ + "### Prepare Data" + ] + }, + { + "cell_type": "markdown", + "id": "702e8e92-17b2-4f34-a2d9-c72b94501bf5", + "metadata": {}, + "source": [ + "Download and preprocess AN4 data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62c7cfec-aa98-4fc5-8b31-23ee1d59f311", + "metadata": {}, + "outputs": [], + "source": [ + "DATASETS_DIR = Path(\"./datasets\") # directory for data\n", + "CHECKPOINTS_DIR = Path(\"./checkpoints/\") # directory for checkpoints" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "659db73e-dcd7-455c-8140-20e104d6ac00", + "metadata": {}, + "outputs": [], + "source": [ + "# create directories if necessary\n", + "DATASETS_DIR.mkdir(parents=True, exist_ok=True)\n", + "CHECKPOINTS_DIR.mkdir(parents=True, exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36830e7f-5293-4401-8c56-780127b47385", + "metadata": {}, + "outputs": [], + "source": [ + "download_an4(data_dir=f\"{DATASETS_DIR}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e77f5062-9acb-4f39-b811-a5b11dd6f76f", + "metadata": {}, + "outputs": [], + "source": [ + "AN4_DATASET = DATASETS_DIR / \"an4\"" + ] + }, + { + "cell_type": "markdown", + "id": "403b63b0-8aab-43aa-a455-31f588d1772f", + "metadata": {}, + "source": [ + "### Construct text-only training data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35654ee1-3869-4289-bd52-15818c0ccf69", + "metadata": {}, + "outputs": [], + "source": [ + "# read original training data\n", + "an4_train_data = read_manifest(AN4_DATASET / \"train_manifest.json\")" + ] + }, + { + "cell_type": "markdown", + "id": "a17f583c-2a5c-4faf-84bd-eb04c2921e01", + "metadata": {}, + "source": [ + "Text-only manifest should contain three fields:\n", + "- `text`: target text for ASR model\n", + "- `tts_text`: text to use as a source for TTS model (unnormalized)\n", + "- `tts_text_normalized`: text to use as a source for TTS model (unnormalized)\n", + "\n", + "If `tts_text_normalized` is not present, `tts_text` will be used, normalization will be done when loading the dataset.\n", + "It is highly recommended to normalize the text and create `tts_text_normalized` field manually, since current normalizers are unsuitable for processing a large amount of text on the fly." 
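+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3f2a1d5e-8c4b-4e2a-9d7f-1a2b3c4d5e6f",
+   "metadata": {},
+   "source": [
+    "To make the format concrete, a single record of such a manifest could look as follows. This is a hypothetical example (the values are illustrative and not taken from AN4); in the manifest file, each record is one JSON object per line:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5b6c7d8e-9f0a-4b1c-8d2e-3f4a5b6c7d8e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hypothetical text-only manifest record (illustrative values)\n",
+    "example_record = {\n",
+    "    \"text\": \"it costs ten dollars\",  # target text for the ASR model\n",
+    "    \"tts_text\": \"It costs $10.\",  # unnormalized source text for the TTS model\n",
+    "    \"tts_text_normalized\": \"It costs ten dollars.\",  # manually normalized TTS text\n",
+    "}"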
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5938a8c2-e239-4a45-a716-dc11a981aec7", + "metadata": {}, + "outputs": [], + "source": [ + "# fill `text` and `tts_text` fields with the source data\n", + "textonly_data = []\n", + "for record in an4_train_data:\n", + " text = record[\"text\"]\n", + " textonly_data.append({\"text\": text, \"tts_text\": text})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7f6a5735-a5c2-4a8b-8116-bfc535a2c299", + "metadata": {}, + "outputs": [], + "source": [ + "WHITELIST_URL = (\n", + " \"https://raw.githubusercontent.com/NVIDIA/NeMo-text-processing/main/\"\n", + " \"nemo_text_processing/text_normalization/en/data/whitelist/lj_speech.tsv\"\n", + ")\n", + "\n", + "\n", + "def get_normalizer() -> Normalizer:\n", + " with tempfile.TemporaryDirectory() as data_dir:\n", + " whitelist_path = Path(data_dir) / \"lj_speech.tsv\"\n", + " if not whitelist_path.exists():\n", + " wget.download(WHITELIST_URL, out=str(data_dir))\n", + "\n", + " normalizer = Normalizer(\n", + " lang=\"en\",\n", + " input_case=\"cased\",\n", + " whitelist=str(whitelist_path),\n", + " overwrite_cache=True,\n", + " cache_dir=None,\n", + " )\n", + " return normalizer" + ] + }, + { + "cell_type": "markdown", + "id": "dd0253aa-d7f1-47ee-a142-099b71241270", + "metadata": {}, + "source": [ + "Сonstruct `tts_text_normalized` field by applying English normalizer to the text.\n", + "\n", + "AN4 data doesn't contain numbers, currency and other entities, so, normalizer is used here only for demonstration purpose." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27bb29d5-d44d-4026-98f8-5f0b1241b39a", + "metadata": {}, + "outputs": [], + "source": [ + "normalizer = get_normalizer()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9400e6d3-ba92-442a-8dd4-117e95dce2ea", + "metadata": {}, + "outputs": [], + "source": [ + "for record in tqdm(textonly_data):\n", + " record[\"tts_text_normalized\"] = normalizer.normalize(\n", + " record[\"tts_text\"], verbose=False, punct_pre_process=True, punct_post_process=True\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "30a934b0-9b58-4bad-bb9a-ab78d81c3859", + "metadata": {}, + "source": [ + "Save manifest" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1833ac15-1750-4468-88bc-2343fbabe4d8", + "metadata": {}, + "outputs": [], + "source": [ + "write_manifest(AN4_DATASET / \"train_text_manifest.json\", textonly_data)" + ] + }, + { + "cell_type": "markdown", + "id": "fa3a2371-8c78-4dd1-9605-a668adf52b4a", + "metadata": {}, + "source": [ + "### Save pretrained checkpoints" + ] + }, + { + "cell_type": "markdown", + "id": "7eb14117-8b8b-4170-ab8c-ce496522a361", + "metadata": {}, + "source": [ + "Firstly we will load pretrained models from NGC, and save as `nemo` checkpoints. \n", + "Our hybrid model will be constructed from these checkpoints." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43c5c75a-b6e0-4b3c-ad26-a07b483d84e6", + "metadata": {}, + "outputs": [], + "source": [ + "ASR_MODEL_PATH = CHECKPOINTS_DIR / \"stt_en_conformer_ctc_small_ls.nemo\"\n", + "TTS_MODEL_PATH = CHECKPOINTS_DIR / \"fastpitch.nemo\"\n", + "ENHANCER_MODEL_PATH = CHECKPOINTS_DIR / \"enhancer.nemo\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40976e22-7a7b-42b2-86a1-9eaaef4c1c22", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# asr model: stt_en_conformer_ctc_small_ls\n", + "asr_model = EncDecCTCModelBPE.from_pretrained(model_name=\"stt_en_conformer_ctc_small_ls\")\n", + "asr_model.save_to(f\"{ASR_MODEL_PATH}\")\n", + "\n", + "# tts model: ...\n", + "# TODO - NGC\n", + "# tts_model = FastPitchModel.from_pretrained(\"...\")\n", + "# tts_model.save_to(f\"{TTS_MODEL_PATH}\")\n", + "\n", + "# enhancer model: ...\n", + "# TODO - NGC\n", + "# enhancer_model = SpectrogramEnhancerModel.from_pretrained(\"...\")\n", + "# enhancer_modelh.save_to(f\"{ENHANCER_MODEL_PATH}\")" + ] + }, + { + "cell_type": "markdown", + "id": "32d1e242-0ab0-43bf-aaa0-997d284c2c1b", + "metadata": {}, + "source": [ + "### Construct hybrid ASR-TTS model " + ] + }, + { + "cell_type": "markdown", + "id": "2210eb07-6d44-44e0-a0ad-866f1e89873a", + "metadata": {}, + "source": [ + "#### Config Parameters\n", + "\n", + "`Hybrid ASR-TTS model` consists of three parts:\n", + "\n", + "* ASR model (``EncDecCTCModelBPE``, ``EncDecRNNTBPEModel`` or ``EncDecHybridRNNTCTCBPEModel``)\n", + "* TTS Mel Spectrogram Generator (currently, only `FastPitch` model is supported)\n", + "* Enhancer model (optional)\n", + "\n", + "Also, the config allows to specify text-only dataset.\n", + "\n", + "Main parts of the config:\n", + "\n", + "* ASR model\n", + " * ``asr_model_path``: path to the ASR model checkpoint (`.nemo`) file, loaded only once, then the config of the ASR model is stored in the ``asr_model`` field\n", + " * ``asr_model_type``: needed only when training from scratch. ``rnnt_bpe`` corresponds to ``EncDecRNNTBPEModel``, ``ctc_bpe`` to ``EncDecCTCModelBPE``, ``hybrid_rnnt_ctc_bpe`` to ``EncDecHybridRNNTCTCBPEModel``\n", + " * ``asr_model_fuse_bn``: fusing BatchNorm in the pretrained ASR model, can improve quality in finetuning scenario\n", + "* TTS model\n", + " * ``tts_model_path``: path to the pretrained TTS model checkpoint (`.nemo`) file, loaded only once, then the config of the model is stored in the ``tts_model`` field\n", + "* Enhancer model\n", + " * ``enhancer_model_path``: optional path to the enhancer model. Loaded only once, the config is stored in the ``enhancer_model`` field\n", + "* ``train_ds``\n", + " * ``text_data``: properties related to text-only data\n", + " * ``manifest_filepath``: path (or paths) to text-only dataset manifests\n", + " * ``speakers_filepath``: path (or paths) to the text file containing speaker ids for the multi-speaker TTS model (speakers are sampled randomly during training)\n", + " * ``min_words`` and ``max_words``: parameters to filter text-only manifests by the number of words\n", + " * ``tokenizer_workers``: number of workers for initial tokenization (when loading the data). ``num_CPUs / num_GPUs`` is a recommended value.\n", + " * ``asr_tts_sampling_technique``, ``asr_tts_sampling_temperature``, ``asr_tts_sampling_probabilities``: sampling parameters for text-only and audio-text data (if both specified). 
Correspond to ``sampling_technique``, ``sampling_temperature``, and ``sampling_probabilities`` parameters of the `nemo.collections.common.data.dataset.ConcatDataset`.\n",
+    "  * all other components are similar to conventional ASR models\n",
+    "* ``validation_ds`` and ``test_ds`` correspond to the underlying ASR model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4d6dd499-d388-4ee3-9a01-d739b16e6ad7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# load config\n",
+    "!wget -P configs/ https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/examples/asr/conf/asr_tts/hybrid_asr_tts.yaml"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d6701dc8-cb3b-44cc-aab5-fb6e2c1dadb5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "config = OmegaConf.load(\"./configs/hybrid_asr_tts.yaml\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c13b3c96-4074-415f-95d2-17569886bfcd",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "NUM_EPOCHS = 10"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1c41e5e8-d926-4b83-8725-bae5a82121cf",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "TTS_SPEAKERS_PATH = Path(\"./TODO\") # store somewhere"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9c07c07c-cb15-4a1c-80bf-20eaffaa65d9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "config.model.asr_model_path = ASR_MODEL_PATH\n",
+    "config.model.tts_model_path = TTS_MODEL_PATH\n",
+    "config.model.enhancer_model_path = ENHANCER_MODEL_PATH\n",
+    "\n",
+    "# fuse BatchNorm automatically in Conformer for better performance\n",
+    "config.model.asr_model_fuse_bn = True \n",
+    "\n",
+    "# training data\n",
+    "# constructed dataset\n",
+    "config.model.train_ds.text_data.manifest_filepath = str(AN4_DATASET / \"train_text_manifest.json\")\n",
+    "# speakers for TTS model\n",
+    "config.model.train_ds.text_data.speakers_filepath = f\"{TTS_SPEAKERS_PATH}\"\n",
+    "config.model.train_ds.manifest_filepath = None # audio-text pairs - we don't use them here\n",
+    "config.model.train_ds.batch_size = 8\n",
+    "\n",
+    "# validation data\n",
+    "config.model.validation_ds.manifest_filepath = str(AN4_DATASET / \"test_manifest.json\")\n",
+    "config.model.validation_ds.batch_size = 8\n",
+    "\n",
+    "config.trainer.max_epochs = NUM_EPOCHS\n",
+    "\n",
+    "config.trainer.devices = 1\n",
+    "config.trainer.strategy = None # use 1 device, no need for ddp strategy\n",
+    "\n",
+    "OmegaConf.resolve(config)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8ae6cb2e-f571-4b53-8897-bb8ba0fc1146",
+   "metadata": {},
+   "source": [
+    "#### Construct trainer and ASRWithTTSModel"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ac4ae885-dec4-4ce9-8f69-a1f35d04b08c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "trainer = pl.Trainer(**config.trainer)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8f815762-b08d-4d3c-8fd3-61afa511eab4",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "hybrid_model = ASRWithTTSModel(config.model)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ca2c1bf2-28a9-4902-9c73-d96e04b21a46",
+   "metadata": {},
+   "source": [
+    "#### Validate the model\n",
+    "\n",
+    "Expect `~17.7%` WER on the AN4 test data."
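+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "7c8d9e0f-1a2b-4c3d-8e4f-5a6b7c8d9e0f",
+   "metadata": {},
+   "source": [
+    "The cell after the next one runs validation via the Lightning trainer. As an optional cross-check, WER can also be computed manually. The following is a sketch: it assumes that `transcribe()` is delegated to the wrapped ASR model and that the `audio_filepath` entries in the test manifest are absolute paths."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9e0f1a2b-3c4d-4e5f-8a6b-7c8d9e0f1a2b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional sanity check (sketch): transcribe the test set and compute WER directly\n",
+    "from nemo.collections.asr.metrics.wer import word_error_rate\n",
+    "\n",
+    "test_records = read_manifest(AN4_DATASET / \"test_manifest.json\")\n",
+    "audio_files = [record[\"audio_filepath\"] for record in test_records]\n",
+    "references = [record[\"text\"] for record in test_records]\n",
+    "\n",
+    "# ASRWithTTSModel is expected to delegate transcription to the wrapped ASR model\n",
+    "hypotheses = hybrid_model.transcribe(audio_files, batch_size=8)\n",
+    "\n",
+    "print(f\"WER: {word_error_rate(hypotheses=hypotheses, references=references):.2%}\")"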
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ffa5f5c6-0609-4f46-aa0c-747319035417",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "trainer.validate(hybrid_model)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "701ee9c7-91a1-4917-bf7d-ab26b625c7bf",
+   "metadata": {},
+   "source": [
+    "#### Train the model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f79761c9-b882-4f14-911f-4a960ff81554",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "trainer.fit(hybrid_model)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "eac18c7c-bdcb-40ad-9c50-37f89fb4aa2a",
+   "metadata": {},
+   "source": [
+    "#### Validate the model after training\n",
+    "\n",
+    "Expect `~2%` WER on the AN4 test data."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "cd927e87-13fb-4b61-8b4a-a6850780f605",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "trainer.validate(hybrid_model)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6d25a77d-35ed-44b5-9ef5-318afa321acf",
+   "metadata": {},
+   "source": [
+    "### Save final model and extract pure ASR model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0f53ebd3-b89a-47e4-a0a5-ed3a3572f7c1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# save the full model: it can be further used for finetuning\n",
+    "hybrid_model.save_to(\"checkpoints/finetuned_hybrid_model.nemo\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f0560c2c-af28-4d8f-b36d-c18ec6a482a8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# extract the resulting ASR model from the hybrid model\n",
+    "hybrid_model.save_asr_model_to(\"checkpoints/finetuned_asr_model.nemo\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2de58fbb-50be-42cd-9095-01cacfdb6931",
+   "metadata": {},
+   "source": [
+    "## Using Scripts (examples)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "86655198-b1fc-4615-958c-7c01f3cbd024",
+   "metadata": {},
+   "source": [
+    "The `/examples/asr/asr_with_tts/` directory contains scripts for finetuning existing models and training new models from scratch."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b5837536-8280-475c-a581-caaee00edfca",
+   "metadata": {},
+   "source": [
+    "### Finetuning Existing Model"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "84df9aeb-3b5e-41fc-a8d0-dfc660e71375",
+   "metadata": {},
+   "source": [
+    "To finetune an existing ASR model using text-only data, use the `/examples/asr/asr_with_tts/speech_to_text_bpe_with_text_finetune.py` script with the corresponding config `/examples/asr/conf/asr_tts/hybrid_asr_tts.yaml`.\n",
+    "\n",
+    "Please specify paths to all the required models (ASR, TTS, and Enhancer checkpoints), along with `train_ds.text_data.manifest_filepath` and `train_ds.text_data.speakers_filepath`."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "78b9028c-02ce-4af4-b510-a431f4a2f62b",
+   "metadata": {},
+   "source": [
+    "```shell\n",
+    "python speech_to_text_bpe_with_text_finetune.py \\\n",
+    "    model.asr_model_path=<path to ASR model> \\\n",
+    "    model.tts_model_path=<path to compatible TTS model> \\\n",
+    "    model.enhancer_model_path=<optional path to enhancer model> \\\n",
+    "    model.asr_model_fuse_bn=<true recommended if ConformerEncoder with BatchNorm, false otherwise> \\\n",
+    "    model.train_ds.manifest_filepath=<path to manifest with audio-text pairs or null> \\\n",
+    "    model.train_ds.text_data.manifest_filepath=<path(s) to manifest with text-only data> \\\n",
+    "    model.train_ds.text_data.speakers_filepath=<path(s) to list of speakers> \\\n",
+    "    model.train_ds.text_data.tokenizer_workers=4 \\\n",
+    "    model.validation_ds.manifest_filepath=<path to validation manifest> \\\n",
+    "    model.train_ds.batch_size=<batch size>\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0b17c097-a3b1-49a3-8f54-f07b94218d0b",
+   "metadata": {},
+   "source": [
+    "### Training a New Model from Scratch"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6d75b928-57b3-4180-bd09-37e018eef7ef",
+   "metadata": {},
+   "source": [
+    "```shell\n",
+    "python speech_to_text_bpe_with_text.py \\\n",
+    "    # (Optional: --config-path=<path to dir of configs> --config-name=<config name>) \\\n",
+    "    ++asr_model_type=<rnnt_bpe, ctc_bpe or hybrid_rnnt_ctc_bpe> \\\n",
+    "    ++tts_model_path=<path to compatible TTS model> \\\n",
+    "    ++enhancer_model_path=<optional path to enhancer model> \\\n",
+    "    model.tokenizer.dir=<path to tokenizer> \\\n",
+    "    model.tokenizer.type=\"bpe\" \\\n",
+    "    model.train_ds.manifest_filepath=<path(s) to manifests with audio-text pairs or null> \\\n",
+    "    ++model.train_ds.text_data.manifest_filepath=<path(s) to manifests with target texts> \\\n",
+    "    ++model.train_ds.text_data.speakers_filepath=<path(s) to speakers list> \\\n",
+    "    ++model.train_ds.text_data.min_words=1 \\\n",
+    "    ++model.train_ds.text_data.max_words=45 \\\n",
+    "    ++model.train_ds.text_data.tokenizer_workers=4 \\\n",
+    "    model.validation_ds.manifest_filepath=<path to validation manifest> \\\n",
+    "    model.train_ds.batch_size=<batch size> \\\n",
+    "    trainer.max_epochs=<num epochs> \\\n",
+    "    trainer.num_nodes=<number of nodes> \\\n",
+    "    trainer.accumulate_grad_batches=<grad accumulation steps>\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9a9a6cd3-4bdc-4b6e-b4b1-3bfd50fd01b3",
+   "metadata": {},
+   "source": [
+    "## Conclusion"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e2890c61-e4b7-47aa-a086-bc483ae7141f",
+   "metadata": {},
+   "source": [
+    "This tutorial demonstrated the main concepts behind hybrid ASR-TTS models, which can be used to finetune existing ASR models and to train new ones from scratch. \n",
+    "The ability to achieve good text-only adaptation results was demonstrated by finetuning a small Conformer model on text-only data from the AN4 dataset."
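+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b1c2d3e4-f5a6-4b7c-8d9e-0f1a2b3c4d5e",
+   "metadata": {},
+   "source": [
+    "As a final sanity check, the extracted ASR model can be restored and used for transcription like any standalone NeMo ASR model. The following is a sketch that assumes the standard NeMo `restore_from`/`transcribe` API and reuses the test manifest loaded earlier in this notebook."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d3e4f5a6-b7c8-4d9e-8f0a-1b2c3d4e5f6a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Restore the extracted (pure ASR) model from the checkpoint saved above\n",
+    "restored_asr_model = EncDecCTCModelBPE.restore_from(\"checkpoints/finetuned_asr_model.nemo\")\n",
+    "\n",
+    "# Transcribe the first utterance of the AN4 test set as a quick smoke test\n",
+    "test_records = read_manifest(AN4_DATASET / \"test_manifest.json\")\n",
+    "print(restored_asr_model.transcribe([test_records[0][\"audio_filepath\"]]))"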
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ml38", + "language": "python", + "name": "ml38" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 0894b32cd345ba94d4413b784a9ebcd17d699a7f Mon Sep 17 00:00:00 2001 From: Vladimir Bataev Date: Sat, 1 Jul 2023 00:08:33 +0400 Subject: [PATCH 2/6] Fix enhancer usage Signed-off-by: Vladimir Bataev --- nemo/collections/asr/models/hybrid_asr_tts_models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nemo/collections/asr/models/hybrid_asr_tts_models.py b/nemo/collections/asr/models/hybrid_asr_tts_models.py index 8486f956c3b7..8494a093b29d 100644 --- a/nemo/collections/asr/models/hybrid_asr_tts_models.py +++ b/nemo/collections/asr/models/hybrid_asr_tts_models.py @@ -311,8 +311,10 @@ def from_pretrained_models( ) ) else: + cfg = copy.deepcopy(cfg) # copy to avoid modifying original config cfg.tts_model_path = f"{tts_model_path}" cfg.asr_model_path = f"{asr_model_path}" + cfg.enhancer_model_path = f"{enhancer_model_path}" if enhancer_model_path is not None else None return ASRWithTTSModel(cfg, trainer=trainer) def __setattr__(self, name, value): From 66704d8e90a75b505d8e766281bfad3768ce9626 Mon Sep 17 00:00:00 2001 From: Vladimir Bataev Date: Tue, 11 Jul 2023 11:20:56 -0700 Subject: [PATCH 3/6] Use models from NGC. Fix grammar Signed-off-by: Vladimir Bataev --- tutorials/asr/ASR_TTS_Tutorial.ipynb | 48 ++++++++++++++++++---------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/tutorials/asr/ASR_TTS_Tutorial.ipynb b/tutorials/asr/ASR_TTS_Tutorial.ipynb index 567c7df19088..7051b853b445 100644 --- a/tutorials/asr/ASR_TTS_Tutorial.ipynb +++ b/tutorials/asr/ASR_TTS_Tutorial.ipynb @@ -5,7 +5,7 @@ "id": "a3570803-9bfa-4e97-9891-5ae0759eb8ca", "metadata": {}, "source": [ - "# Hybrid ASR-TTS models tutorial" + "# Hybrid ASR-TTS Models Tutorial" ] }, { @@ -13,7 +13,7 @@ "id": "50fc294f-f319-4465-8f90-a28b49843e60", "metadata": {}, "source": [ - "This tutorial is designed to introduce using Hybrid ASR-TTS models, aka `ASRWithTTSModel`, to finetune existing ASR models using an integrated text-to-mel-spectrogram generator. " + "This tutorial is intended to introduce you to using ASR-TTS Hybrid Models, also known as `ASRWithTTSModel`, to finetune existing ASR models using an integrated text-to-mel-spectrogram generator. " ] }, { @@ -49,7 +49,7 @@ "\n", "The architecture is shown in the figure. \n", "\n", - "The model can take text or audio as input during training. In the case of audio input, a mel spectrogram is extracted as usual and passed to ASR neural network. In the case of textual input, the mel spectrogram generator produces a spectrogram on the fly from the text. The spectrogram is improved by the enhancer (if present) and fed into the ASR model. \n", + "The model can take text or audio as input during training. In the case of audio input, a mel spectrogram is extracted as usual and passed to the ASR neural network. In the case of textual input, the mel spectrogram generator produces a spectrogram on the fly from the text. The spectrogram is improved by the enhancer (if present) and fed into the ASR model. 
\n", "\n", "### Capabilities and limitations\n", "\n", @@ -272,7 +272,7 @@ "Text-only manifest should contain three fields:\n", "- `text`: target text for ASR model\n", "- `tts_text`: text to use as a source for TTS model (unnormalized)\n", - "- `tts_text_normalized`: text to use as a source for TTS model (unnormalized)\n", + "- `tts_text_normalized`: text to use as a source for TTS model (normalized)\n", "\n", "If `tts_text_normalized` is not present, `tts_text` will be used, normalization will be done when loading the dataset.\n", "It is highly recommended to normalize the text and create `tts_text_normalized` field manually, since current normalizers are unsuitable for processing a large amount of text on the fly." @@ -385,8 +385,12 @@ "id": "7eb14117-8b8b-4170-ab8c-ce496522a361", "metadata": {}, "source": [ - "Firstly we will load pretrained models from NGC, and save as `nemo` checkpoints. \n", - "Our hybrid model will be constructed from these checkpoints." + "Firstly we will load pretrained models from NGC, and save them as `nemo` checkpoints. \n", + "Our hybrid model will be constructed from these checkpoints.\n", + "We will use:\n", + "- small Conformer-CTC ASR model trained on LibriSpeech data (for finetuning)\n", + "- multi-speaker TTS FastPitch model is trained on LibriTTS data. Spectrogram parameters for this model are the same as those used in the ASR model\n", + "- enhancer, which is trained adversarially on the output of the TTS model and natural spectrograms" ] }, { @@ -414,15 +418,13 @@ "asr_model = EncDecCTCModelBPE.from_pretrained(model_name=\"stt_en_conformer_ctc_small_ls\")\n", "asr_model.save_to(f\"{ASR_MODEL_PATH}\")\n", "\n", - "# tts model: ...\n", - "# TODO - NGC\n", - "# tts_model = FastPitchModel.from_pretrained(\"...\")\n", - "# tts_model.save_to(f\"{TTS_MODEL_PATH}\")\n", + "# tts model: tts_en_fastpitch_for_asr_finetuning\n", + "tts_model = FastPitchModel.from_pretrained(model_name=\"tts_en_fastpitch_for_asr_finetuning\")\n", + "tts_model.save_to(f\"{TTS_MODEL_PATH}\")\n", "\n", - "# enhancer model: ...\n", - "# TODO - NGC\n", - "# enhancer_model = SpectrogramEnhancerModel.from_pretrained(\"...\")\n", - "# enhancer_modelh.save_to(f\"{ENHANCER_MODEL_PATH}\")" + "# enhancer model: tts_en_spectrogram_enhancer_for_asr_finetuning\n", + "enhancer_model = SpectrogramEnhancerModel.from_pretrained(model_name=\"tts_en_spectrogram_enhancer_for_asr_finetuning\")\n", + "enhancer_model.save_to(f\"{ENHANCER_MODEL_PATH}\")" ] }, { @@ -446,7 +448,7 @@ "* TTS Mel Spectrogram Generator (currently, only `FastPitch` model is supported)\n", "* Enhancer model (optional)\n", "\n", - "Also, the config allows to specify text-only dataset.\n", + "Also, the config allows to specify a text-only dataset.\n", "\n", "Main parts of the config:\n", "\n", @@ -500,6 +502,14 @@ "NUM_EPOCHS = 10" ] }, + { + "cell_type": "markdown", + "id": "4d090c5d-44a7-401a-a753-b8779b1c1e0b", + "metadata": {}, + "source": [ + "We will use all available speakers (sampled uniformly)." 
+   ]
+  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c41e5e8-d926-4b83-8725-bae5a82121cf",
   "metadata": {},
   "outputs": [],
   "source": [
-    "TTS_SPEAKERS_PATH = Path(\"./TODO\") # store somewhere"
+    "TTS_SPEAKERS_PATH = Path(\"./checkpoints/speakers.txt\")\n",
+    "\n",
+    "with open(TTS_SPEAKERS_PATH, \"w\", encoding=\"utf-8\") as f:\n",
+    "    for speaker_id in range(tts_model.cfg.n_speakers):\n",
+    "        print(speaker_id, file=f)"
   ]
  },
@@ -760,7 +774,7 @@
   "id": "9a9a6cd3-4bdc-4b6e-b4b1-3bfd50fd01b3",
   "metadata": {},
   "source": [
-    "## Conclusion"
+    "## Sumary"
   ]
  },

From ea8e6f72eb15464544164a0c620f4e4b509cdaac Mon Sep 17 00:00:00 2001
From: Vladimir Bataev
Date: Tue, 11 Jul 2023 22:21:37 +0400
Subject: [PATCH 4/6] Add tutorial to docs

Signed-off-by: Vladimir Bataev
---
 docs/source/starthere/tutorials.rst | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/docs/source/starthere/tutorials.rst b/docs/source/starthere/tutorials.rst
index 9c960053398b..3859cb630d02 100644
--- a/docs/source/starthere/tutorials.rst
+++ b/docs/source/starthere/tutorials.rst
@@ -106,6 +106,9 @@ To run a tutorial:
    * - ASR
      - Multi-lingual ASR
      - `Multi-lingual ASR `_
+   * - ASR
+     - Hybrid ASR-TTS Models Tutorial
+     - `Hybrid ASR-TTS Models `_
    * - NLP
      - Using Pretrained Language Models for Downstream Tasks
      - `Pretrained Language Models for Downstream Tasks `_

From b8ab5bb4b24dfc4d6668930a7185f6173b2c03a6 Mon Sep 17 00:00:00 2001
From: Vladimir Bataev
Date: Tue, 11 Jul 2023 12:02:10 -0700
Subject: [PATCH 5/6] Add section about TTS models

Signed-off-by: Vladimir Bataev
---
 tutorials/asr/ASR_TTS_Tutorial.ipynb | 57 ++++++++++++++++++++++------
 1 file changed, 46 insertions(+), 11 deletions(-)

diff --git a/tutorials/asr/ASR_TTS_Tutorial.ipynb b/tutorials/asr/ASR_TTS_Tutorial.ipynb
index 7051b853b445..ea699eacf969 100644
--- a/tutorials/asr/ASR_TTS_Tutorial.ipynb
+++ b/tutorials/asr/ASR_TTS_Tutorial.ipynb
@@ -51,7 +51,7 @@
 "\n",
 "The model can take text or audio as input during training. In the case of audio input, a mel spectrogram is extracted as usual and passed to the ASR neural network. In the case of textual input, the mel spectrogram generator produces a spectrogram on the fly from the text. The spectrogram is improved by the enhancer (if present) and fed into the ASR model. \n",
 "\n",
-    "### Capabilities and limitations\n",
+    "### Capabilities and Limitations\n",
 "\n",
 "This approach can be used to finetune the pretrained ASR model using text-only data. Training new models from scratch is also possible. The text should contain phrases and sentences and be split into sentences (~45 words maximum, corresponding to ~16.7 seconds of synthesized audio). Using only separate words is not recommended, since this does not allow the ASR model to adapt to recognizing new words in context. \n",
@@ -69,7 +69,7 @@
 "id": "2702d081-c675-4a96-8263-6059e310d048",
 "metadata": {},
 "source": [
-    "## Example: Finetuning ASR model using text-only data"
+    "## Example: Finetuning ASR Model Using Text-Only Data"
 ]
 },
@@ -96,7 +96,7 @@
 "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
 "4. Run the following cell to set up dependencies.\n",
 "\n",
-    "NOTE: User is responsible for checking the content of datasets and the applicable licenses and determining if suitable for the intended use."
+    "NOTE: The user is responsible for checking the content of datasets and the applicable licenses and determining if they are suitable for the intended use."
 ]
 },
 {
@@ -270,12 +270,12 @@
 "metadata": {},
 "source": [
 "Text-only manifest should contain three fields:\n",
-    "- `text`: target text for ASR model\n",
-    "- `tts_text`: text to use as a source for TTS model (unnormalized)\n",
+    "- `text`: target text for the ASR model\n",
+    "- `tts_text`: text to use as a source for the TTS model (unnormalized)\n",
 "- `tts_text_normalized`: text to use as a source for TTS model (normalized)\n",
 "\n",
-    "If `tts_text_normalized` is not present, `tts_text` will be used, normalization will be done when loading the dataset.\n",
-    "It is highly recommended to normalize the text and create `tts_text_normalized` field manually, since current normalizers are unsuitable for processing a large amount of text on the fly."
+    "If `tts_text_normalized` is not present, `tts_text` will be used, and normalization will be done when loading the dataset.\n",
+    "It is highly recommended to normalize the text and manually create the `tts_text_normalized` field since current normalizers are unsuitable for processing a large amount of text on the fly."
@@ -326,9 +326,9 @@
 "id": "dd0253aa-d7f1-47ee-a142-099b71241270",
 "metadata": {},
 "source": [
-    "Сonstruct `tts_text_normalized` field by applying English normalizer to the text.\n",
+    "Construct the `tts_text_normalized` field by applying an English normalizer to the text.\n",
 "\n",
-    "AN4 data doesn't contain numbers, currency and other entities, so, normalizer is used here only for demonstration purpose."
+    "AN4 data doesn't contain numbers, currency, and other entities, so the normalizer is used here only for demonstration purposes."
@@ -389,7 +389,7 @@
 "id": "7eb14117-8b8b-4170-ab8c-ce496522a361",
 "metadata": {},
 "source": [
-    "Firstly we will load pretrained models from NGC, and save them as `nemo` checkpoints. \n",
+    "Firstly we will load pretrained models from NGC and save them as `nemo` checkpoints. \n",
 "Our hybrid model will be constructed from these checkpoints.\n",
 "We will use:\n",
 "- small Conformer-CTC ASR model trained on LibriSpeech data (for finetuning)\n",
@@ -769,11 +769,46 @@
 "```"
 ]
 },
+  {
+   "cell_type": "markdown",
+   "id": "01c17712-ae8d-49cb-ade1-ded168676e27",
+   "metadata": {},
+   "source": [
+    "## Training TTS Models for ASR Finetuning"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "422dc3b2-d29f-4ed0-b4d2-6d32b35dfb7b",
+   "metadata": {},
+   "source": [
+    "### TTS Model (FastPitch)\n",
+    "\n",
+    "A TTS model used for ASR model finetuning should be trained with the same mel spectrogram parameters as used in the ASR model. The typical parameters are `10ms` hop length, `25ms` window length, and the highest band of 8kHz (for 16kHz data). Other parameters are the same as for common multi-speaker TTS models.\n",
+    "\n",
+    "We observed mainly two differences specific to TTS models for ASR:\n",
+    "- adding more speakers and more data improves the final ASR model quality (but not the perceptual quality of the TTS model)\n",
+    "- training for more epochs can also improve the quality of the ASR system (although the MSE loss of the TTS model on validation data can be higher than optimal)\n",
+    "\n",
+    "Use the script `/examples/tts/fastpitch.py` to train a FastPitch model.\n",
+    "More details about the FastPitch model can be found in the [documentation](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/tts/models.html#fastpitch). \n",
+    "\n",
+    "### Enhancer\n",
+    "Use the script `/examples/tts/spectrogram_enhancer.py` to train an Enhancer model. More details can be found in the \n",
+    "[documentation](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/tts/models.html).\n",
+    "\n",
+    "### Models Used in This Tutorial\n",
+    "\n",
+    "Some details about the models used in this tutorial can be found on [NGC](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/tts_en_fastpitch_spectrogram_enhancer_for_asr_finetuning).\n",
+    "\n",
+    "The system is also described in detail in the paper [Text-only domain adaptation for end-to-end ASR using integrated text-to-mel-spectrogram generator](https://arxiv.org/abs/2302.14036)."
+   ]
+  },
 {
 "cell_type": "markdown",
 "id": "9a9a6cd3-4bdc-4b6e-b4b1-3bfd50fd01b3",
 "metadata": {},
 "source": [
-    "## Sumary"
+    "## Summary"
 ]
 },

From 260030cc84f6f6d59b00f8c01552097ea9691d19 Mon Sep 17 00:00:00 2001
From: Vladimir Bataev
Date: Tue, 11 Jul 2023 12:03:18 -0700
Subject: [PATCH 6/6] Reformat notebook

Signed-off-by: Vladimir Bataev
---
 tutorials/asr/ASR_TTS_Tutorial.ipynb | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tutorials/asr/ASR_TTS_Tutorial.ipynb b/tutorials/asr/ASR_TTS_Tutorial.ipynb
index ea699eacf969..939ef8a28d29 100644
--- a/tutorials/asr/ASR_TTS_Tutorial.ipynb
+++ b/tutorials/asr/ASR_TTS_Tutorial.ipynb
@@ -536,7 +536,7 @@
 "config.model.enhancer_model_path = ENHANCER_MODEL_PATH\n",
 "\n",
 "# fuse BatchNorm automatically in Conformer for better performance\n",
-    "config.model.asr_model_fuse_bn = True \n",
+    "config.model.asr_model_fuse_bn = True\n",
 "\n",
 "# training data\n",
 "# constructed dataset\n",