
shivammehta25/BetterFastSpeech2


BetterFastSpeech 2


This is the standard FastSpeech 2 architecture with a few modifications. I wanted to make the code base better and more readable, and to finally have an open-source implementation of FastSpeech 2 that doesn't sound bad and is easy to hack on and work with.

If you like this, you will love Matcha-TTS.

Changes from the original architecture:

  • Instead of using MFA (Montreal Forced Aligner), I obtained alignments from a pretrained Matcha-TTS model.
    • To save myself the pain of setting up and training MFA.
  • Used IPA phonemes with blank tokens interspersed between phones.
  • No learning-rate decay.
  • Duration prediction in the log domain.
  • Everyone seems to use the postnet from Tacotron 2, so I've used it as well.
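Two of the changes above, blank tokens between phones and log-domain durations, can be sketched in a few lines of Python. This is a minimal illustration, not code taken from the repository: the helper names (`intersperse`, `durations_to_log`, `log_to_durations`) are hypothetical, the blank id 0 is an assumption, and so is the `log(d + 1)` target, whose `+1` keeps zero-frame durations finite in the log domain. The interspersing follows the scheme used in Glow-TTS and Matcha-TTS, where a blank also sits at both ends of the sequence.

```python
import math
from typing import List


def intersperse(phone_ids: List[int], blank_id: int = 0) -> List[int]:
    """Put a blank token before, between and after the phones.

    blank_id=0 is an assumption for illustration.
    """
    result = [blank_id] * (2 * len(phone_ids) + 1)
    result[1::2] = phone_ids  # phones occupy the odd positions
    return result


def durations_to_log(durations: List[int]) -> List[float]:
    """Training targets for a log-domain duration predictor (assumed log(d + 1))."""
    return [math.log(d + 1) for d in durations]


def log_to_durations(log_durations: List[float]) -> List[int]:
    """Map log-domain predictions back to whole frame counts, clamped at 0."""
    return [max(0, round(math.exp(x) - 1)) for x in log_durations]


if __name__ == "__main__":
    print(intersperse([12, 7, 31]))
    frames = [0, 5, 1, 8, 0, 3, 2]
    print(log_to_durations(durations_to_log(frames)))
```

Rounding and clamping at inference turn the log-domain prediction back into whole, non-negative frame counts.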

Link to LJ Speech checkpoint. Running the code locally with the CLI will auto-download the checkpoint as well.

Installation

  1. Create an environment (suggested but optional)
conda create -n betterfs2 python=3.10 -y
conda activate betterfs2
  2. Install from source
git clone https://github.com/shivammehta25/BetterFastSpeech2.git
cd BetterFastSpeech2
pip install -e .
  3. Run the CLI, Gradio app, or Jupyter notebook
# This will download the required models
betterfs2 --text "<INPUT TEXT>"

or open synthesis.ipynb in Jupyter Notebook.

Train with your own dataset

Let's assume we are training with LJ Speech.

  1. Download the dataset from here, extract it to data/LJSpeech-1.1, and prepare the file lists to point to the extracted data, as in item 5 of the setup instructions in the NVIDIA Tacotron 2 repo.

  2. Train a Matcha-TTS model to extract durations, or use a pretrained one if you already have it.
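Matcha-TTS learns its text-to-speech alignment with monotonic alignment search, which yields a hard (0/1) alignment between input tokens and mel frames; a token's duration is then the number of frames assigned to it. The sketch below shows that reduction for one utterance, assuming the alignment is available as a tokens-by-frames matrix. `alignment_to_durations` is a hypothetical helper, not the script that Matcha-TTS or this repository uses to write the `durations/` folder.

```python
from typing import List, Sequence


def alignment_to_durations(alignment: Sequence[Sequence[int]]) -> List[int]:
    """Count the mel frames assigned to each input token.

    `alignment` is an assumed (num_tokens x num_frames) 0/1 matrix in which
    every frame belongs to exactly one token.
    """
    if not alignment:
        return []
    num_frames = len(alignment[0])
    if any(len(row) != num_frames for row in alignment):
        raise ValueError("all rows must have the same number of frames")

    owners = []  # index of the token that owns each frame
    for t in range(num_frames):
        column = [row[t] for row in alignment]
        if sum(column) != 1:
            raise ValueError(f"frame {t} is not assigned to exactly one token")
        owners.append(column.index(1))

    # Monotonic: the owning token never moves backwards in time
    if any(b < a for a, b in zip(owners, owners[1:])):
        raise ValueError("alignment is not monotonic")

    return [sum(row) for row in alignment]


if __name__ == "__main__":
    alignment = [
        [1, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 1],
    ]
    print(alignment_to_durations(alignment))
```

The durations of an utterance always sum to its number of mel frames, which is a cheap sanity check on whatever files end up in `durations/`.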

Your data directory should look like:

data/
└── LJSpeech-1.1
    ├── durations/ # Here: durations extracted with Matcha-TTS
    ├── metadata.csv
    ├── README
    ├── test.txt
    ├── train.txt
    ├── val.txt
    └── wavs/
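Before training, it can be worth checking that the file lists from step 1 point at audio that actually exists in this layout. The sketch below assumes the NVIDIA Tacotron 2 file-list format, one `path/to/file.wav|transcript` entry per line, with the `DUMMY` directory placeholder those lists ship with. `rewrite_placeholder` and `missing_wavs` are hypothetical helpers, not part of BetterFastSpeech2.

```python
from pathlib import Path
from typing import List


def rewrite_placeholder(filelist: Path, wav_dir: str, placeholder: str = "DUMMY") -> None:
    """Replace a leading placeholder directory in every entry with `wav_dir`, in place."""
    fixed = []
    for line in filelist.read_text(encoding="utf-8").splitlines():
        path, sep, rest = line.partition("|")
        if path.startswith(placeholder + "/"):
            path = wav_dir.rstrip("/") + path[len(placeholder):]
        fixed.append(path + sep + rest)
    filelist.write_text("".join(entry + "\n" for entry in fixed), encoding="utf-8")


def missing_wavs(filelist: Path, root: Path = Path(".")) -> List[str]:
    """Return the audio paths listed in `filelist` that do not exist under `root`."""
    missing = []
    for line in filelist.read_text(encoding="utf-8").splitlines():
        if not line.strip():
            continue
        path = line.split("|", 1)[0]
        if not (root / path).is_file():
            missing.append(path)
    return missing


if __name__ == "__main__":
    data_dir = Path("data/LJSpeech-1.1")
    for name in ("train.txt", "val.txt", "test.txt"):
        filelist = data_dir / name
        if filelist.exists():
            rewrite_placeholder(filelist, "data/LJSpeech-1.1/wavs")
            print(name, len(missing_wavs(filelist)), "missing wav files")
```

Paths are resolved relative to the directory you run from, so run it from the repository root if your file lists use paths like `data/LJSpeech-1.1/wavs/...`.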
  3. Clone and enter the BetterFastSpeech2 repository
git clone https://github.com/shivammehta25/BetterFastSpeech2.git
cd BetterFastSpeech2
  4. Install the package from source
pip install -e .
  5. Go to configs/data/ljspeech.yaml and change the following to the paths of your train and validation filelists:
train_filelist_path: data/LJSpeech-1.1/train.txt
valid_filelist_path: data/LJSpeech-1.1/val.txt
  6. Generate normalisation statistics using the dataset configuration (here, ljspeech):
python fs2/utils/preprocess.py -i ljspeech
# Output:
#{'pitch_min': 67.836174, 'pitch_max': 578.637146, 'pitch_mean': 207.001846, 'pitch_std': 52.747742, 'energy_min': 0.084354, 'energy_max': 190.849121, 'energy_mean': 21.330254, 'energy_std': 17.663319, 'mel_mean': -5.554245, 'mel_std': 2.059021}

Update these values in configs/data/ljspeech.yaml under the data_statistics key, replacing the existing entries with the ones your run prints:

data_statistics:  # Computed for ljspeech dataset
    pitch_min: 67.836174
    pitch_max: 792.962036
    pitch_mean: 211.046158
    pitch_std: 53.012085
    energy_min: 0.023226
    energy_max: 241.037918
    energy_mean: 21.821531
    energy_std: 18.17124
    mel_mean: -5.517035
    mel_std: 2.064413
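The printed statistics are, per feature, a minimum, maximum, mean and standard deviation, presumably taken over every frame in the training data. The sketch below shows that computation and a z-score normalisation with the resulting values; `feature_stats` and `normalise` are hypothetical names, not the contents of fs2/utils/preprocess.py. The printed pitch_min of about 68 Hz suggests unvoiced frames (pitch 0) are left out, so the sketch offers a `drop_zeros` option, but how the repository actually treats them is an assumption. The returned keys follow the data_statistics naming (for mel, only mel_mean and mel_std are used there).

```python
import math
from typing import Dict, Iterable, List, Sequence


def feature_stats(
    name: str, utterances: Iterable[Sequence[float]], drop_zeros: bool = False
) -> Dict[str, float]:
    """Min, max, mean and population std of one feature over all frames.

    drop_zeros=True skips frames equal to 0.0 (e.g. unvoiced pitch frames);
    whether the repository does this is an assumption.
    """
    values = [v for utt in utterances for v in utt if not (drop_zeros and v == 0.0)]
    if not values:
        raise ValueError(f"no frames left to compute {name} statistics")
    mean = sum(values) / len(values)
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))
    return {
        f"{name}_min": min(values),
        f"{name}_max": max(values),
        f"{name}_mean": mean,
        f"{name}_std": std,
    }


def normalise(frames: Sequence[float], mean: float, std: float) -> List[float]:
    """Z-score normalisation with precomputed statistics."""
    return [(v - mean) / std for v in frames]


if __name__ == "__main__":
    # Toy pitch contours in Hz; 0.0 marks unvoiced frames
    pitch = [[0.0, 100.0, 200.0], [300.0, 0.0]]
    stats = feature_stats("pitch", pitch, drop_zeros=True)
    print(stats)
    print(normalise([200.0, 300.0], stats["pitch_mean"], stats["pitch_std"]))
```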

  7. Run the training script
python fs2/train.py experiment=ljspeech
  • For multi-GPU training, run
python fs2/train.py experiment=ljspeech trainer.devices=[0,1]
  8. Synthesise from the custom-trained model
betterfs2 --text "<INPUT TEXT>" --checkpoint_path <PATH TO CHECKPOINT>
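If you want to drive training or synthesis from a Python script (for example, to synthesise many sentences in a loop), the training and synthesis commands above can be assembled as argument lists. This is a minimal sketch: `train_command` and `synth_command` are hypothetical helpers, they use only the `experiment=`, `trainer.devices=`, `--text` and `--checkpoint_path` arguments shown above, and `logs/last.ckpt` is a placeholder path. Passing a list to `subprocess.run` avoids the shell, so the brackets in `trainer.devices=[0,1]` are never treated as a glob pattern.

```python
import shlex
from typing import List, Optional, Sequence


def train_command(experiment: str = "ljspeech", devices: Optional[Sequence[int]] = None) -> List[str]:
    """Build the training invocation; `devices` becomes a Hydra list override."""
    cmd = ["python", "fs2/train.py", f"experiment={experiment}"]
    if devices:
        cmd.append("trainer.devices=[" + ",".join(str(d) for d in devices) + "]")
    return cmd


def synth_command(text: str, checkpoint_path: Optional[str] = None) -> List[str]:
    """Build a betterfs2 CLI call; omitting the checkpoint uses the default model."""
    cmd = ["betterfs2", "--text", text]
    if checkpoint_path is not None:
        cmd += ["--checkpoint_path", checkpoint_path]
    return cmd


if __name__ == "__main__":
    # Show shell-quoted versions; pass the lists to subprocess.run(cmd, check=True) to execute them
    print(shlex.join(train_command(devices=[0, 1])))
    print(shlex.join(synth_command("Hello world.", "logs/last.ckpt")))
```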

Citation information

If you use our code or otherwise find this work useful, please cite our paper:

@inproceedings{mehta2024matcha,
  title={Matcha-{TTS}: A fast {TTS} architecture with conditional flow matching},
  author={Mehta, Shivam and Tu, Ruibo and Beskow, Jonas and Sz{\'e}kely, {\'E}va and Henter, Gustav Eje},
  booktitle={Proc. ICASSP},
  year={2024}
}

Acknowledgements

Since this code uses Lightning-Hydra-Template, you have all the powers that come with it.

Other source code we would like to acknowledge:
