Transformer-based Text-to-Speech with complete training pipeline
A clean PyTorch implementation of a transformer-based text-to-speech system featuring an encoder-decoder architecture, mel-spectrogram generation, and Griffin-Lim vocoding. Based on "Neural Speech Synthesis with Transformer Network" (Li et al., 2019). Includes a complete training pipeline, multi-GPU support, and production-ready inference.
# Clone the repository
git clone <repository-url>
cd Simple-TTS
# Install dependencies
pip install torch torchaudio tensorboard pandas matplotlib scikit-learn tqdm pydub

The system is designed for the LJSpeech dataset:
# Download LJSpeech dataset
wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
tar -xjf LJSpeech-1.1.tar.bz2
# Update paths in configs/Configs.py
csv_path = "/path/to/LJSpeech-1.1/metadata.csv"
wav_path = "/path/to/LJSpeech-1.1/wavs"

from configs.Configs import Configs
from train import TrainSimpleTTS
# Initialize configuration and trainer
config = Configs()
trainer = TrainSimpleTTS(config, device="cuda:0")
# Start training
trainer.train()

Or use the provided training script:
python train.py

import torch
from configs.Configs import Configs
from simple_tts import SimpleTTS
from utils.text_to_seq import text_to_seq
from utils.melspecs import inverse_mel_spec_to_wav
from utils.write_mp3 import write_mp3
# Load configuration and model
config = Configs()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = SimpleTTS(
    text_num_embeddings=config.text_num_embeddings,
    embedding_size=config.embedding_size,
    encoder_embedding_size=config.encoder_embedding_size,
    mel_freq=config.mel_freq,
    max_mel_time=config.max_mel_time,
    dim_feedforward=config.dim_feedforward,
    postnet_embedding_size=config.postnet_embedding_size,
    encoder_kernel_size=config.encoder_kernel_size,
    postnet_kernel_size=config.postnet_kernel_size
).to(device)
# Load trained weights
state = torch.load("path/to/model.pt", map_location=device)
model.load_state_dict(state["model"])
# Generate speech
text = "Hello, world! This is a test of the Simple TTS system."
mel_spec, stop_tokens = model.inference(
    text_to_seq(text).unsqueeze(0).to(device),
    max_length=200,
    stop_token_threshold=0.7
)
# Convert to audio
audio = inverse_mel_spec_to_wav(mel_spec[0].T, device)
write_mp3(audio.cpu().numpy(), "output.mp3")

- EncoderPreNet: Text embedding + 3-layer CNN with batch normalization
- EncoderBlock: Multi-head self-attention + feedforward network
- DecoderPreNet: Mel-spectrogram preprocessing with linear layers
- DecoderBlock: Self-attention + cross-attention + feedforward network
- PostNet: 5-layer CNN for mel-spectrogram refinement
- TextMelDataset: Configurable dataset with caching support
- TTSLoss: Combined MSE (mel) + BCE (stop token) loss
- TrainSimpleTTS: Complete training class with mixed precision
- Device Support: Multi-GPU training with configurable device placement
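
To illustrate the "Combined MSE (mel) + BCE (stop token)" idea behind TTSLoss, here is a minimal sketch. It is not the code in utils/tts_loss.py: the class name `CombinedTTSLoss`, the `stop_weight` parameter, the use of logits for the stop token, and the masking of padded frames are all assumptions made for this example.

```python
import torch
import torch.nn as nn


class CombinedTTSLoss(nn.Module):
    """Hypothetical stand-in for TTSLoss: MSE on mel frames plus BCE on stop tokens."""

    def __init__(self, stop_weight: float = 1.0):
        super().__init__()
        self.stop_weight = stop_weight  # assumed knob, not from the repository
        self.mse = nn.MSELoss(reduction="none")
        self.bce = nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, mel_pred, mel_target, stop_logits, stop_target, mel_lengths):
        # mel_pred, mel_target: (batch, time, n_mels)
        # stop_logits, stop_target: (batch, time); mel_lengths: (batch,)
        time_steps = mel_target.size(1)
        positions = torch.arange(time_steps, device=mel_lengths.device)
        # 1.0 for real frames, 0.0 for padding (assumed masking scheme)
        valid = (positions[None, :] < mel_lengths[:, None]).float()

        # Average the per-frame errors over valid frames only
        mel_loss = (self.mse(mel_pred, mel_target).mean(dim=-1) * valid).sum() / valid.sum()
        stop_loss = (self.bce(stop_logits, stop_target) * valid).sum() / valid.sum()
        return mel_loss + self.stop_weight * stop_loss
```

Masking keeps padded frames in a batch from diluting the loss. If the model already applies a sigmoid to the stop-token output, `nn.BCELoss` would be the matching choice instead of `nn.BCEWithLogitsLoss`.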
All hyperparameters are centralized in configs/Configs.py:
class Configs:
    # Audio parameters
    sr = 22050
    n_fft = 2048
    mel_freq = 128
    
    # Model parameters
    embedding_size = 256
    encoder_embedding_size = 512
    dim_feedforward = 1024
    
    # Training parameters
    batch_size = 32
    lr = 2e-4
    grad_clip = 1.0

- model_development.ipynb: Model inference and visualization
- MelSpectrogram_and_STFT.ipynb: Audio processing pipeline analysis
- GriffinLim_algorithm.ipynb: Griffin-Lim vocoding demonstration
Simple-TTS/
├── configs/
│   └── Configs.py              # Configuration parameters
├── simple_tts/
│   ├── EncoderBlock.py         # Transformer encoder
│   ├── DecoderBlock.py         # Transformer decoder
│   ├── EncoderPreNet.py        # Text preprocessing
│   ├── DecoderPreNet.py        # Mel preprocessing
│   ├── PostNet.py              # Mel refinement
│   └── SimpleTTS.py            # Main model
├── train/
│   └── TrainSimpleTTS.py       # Training pipeline
├── data/
│   └── TextMelDataset.py       # Dataset and collate function
├── utils/
│   ├── melspecs.py             # Mel-spectrogram processing
│   ├── text_to_seq.py          # Text preprocessing
│   ├── tts_loss.py             # Loss function
│   ├── write_mp3.py            # Audio export
│   └── mask_from_seq_lengths.py # Attention masking
├── train.py                    # Training entry point
└── README.md
- Python 3.8+
- PyTorch 1.12+
- torchaudio
- tensorboard
- pandas
- matplotlib
- scikit-learn
- tqdm
- pydub
- Start Small: Begin with a subset of the dataset to verify the pipeline
- Monitor Attention: Check attention alignments in TensorBoard
- Stop Token: Adjust stop_token_threshold for optimal sequence length
- Learning Rate: Use warmup and decay for stable training
- Mixed Precision: Enabled by default for faster training
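
As an illustration of the "warmup and decay" tip, the sketch below ramps the learning rate linearly to a peak and then decays it with the inverse square root of the step. The function name `warmup_inverse_sqrt_lr` and the `warmup_steps=4000` default are assumptions for this example; the peak of 2e-4 mirrors the `lr` value in the configuration, and the repository's own schedule may differ.

```python
def warmup_inverse_sqrt_lr(step: int, peak_lr: float = 2e-4, warmup_steps: int = 4000) -> float:
    """Return the learning rate for a 1-based training step (hypothetical helper)."""
    if step < 1:
        raise ValueError("step must be >= 1")
    if step <= warmup_steps:
        # Linear ramp from near zero up to peak_lr
        return peak_lr * step / warmup_steps
    # Inverse square-root decay after the warmup phase
    return peak_lr * (warmup_steps / step) ** 0.5


if __name__ == "__main__":
    for s in (1, 2000, 4000, 16000):
        print(s, warmup_inverse_sqrt_lr(s))
```

One way to use such a function with PyTorch is through `torch.optim.lr_scheduler.LambdaLR`, which expects a multiplier on the optimizer's base rate, so the value would be divided by `peak_lr` first.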
- Fork the repository
- Create a feature branch
- Make your changes with proper type hints and documentation
- Add tests if applicable
- Submit a pull request
This project is licensed under the MIT License - see the LICENSE file for details.
- Based on Neural Speech Synthesis with Transformer Network (Li et al., 2019)
- Inspired by Tacotron2 for TTS-specific components
- Griffin-Lim algorithm for vocoding
- LJSpeech dataset for training and evaluation
If you use this code in your research, please cite:
@article{li2019neural,
  title={Neural Speech Synthesis with Transformer Network},
  author={Li, Naihan and Liu, Shujie and Liu, Yanqing and Zhao, Sheng and Liu, Ming and Zhou, Ming},
  journal={arXiv preprint arXiv:1809.08895},
  year={2019}
}
@misc{simple-tts,
  title={Simple TTS: Transformer-based Text-to-Speech System},
  author={Abdelrahman Seleem},
  year={2025},
  url={https://github.com/ASeleem/Simple-TTS}
}