dreamgonfly/transformer-pytorch

Transformer-pytorch

A PyTorch implementation of the Transformer from "Attention Is All You Need" (https://arxiv.org/abs/1706.03762).

This repo focuses on a clean, readable, and modular implementation of the paper.

Requirements

Usage

Prepare datasets

This repo comes with example data in the data/ directory. To begin, prepare the datasets from the given data as follows:

$ python prepare_datasets.py --train_source=data/example/raw/src-train.txt --train_target=data/example/raw/tgt-train.txt --val_source=data/example/raw/src-val.txt --val_target=data/example/raw/tgt-val.txt --save_data_dir=data/example/processed

The example data is taken from OpenNMT-py. It consists of parallel source (src) and target (tgt) data for training and validation. Each data file contains one sentence per line, with tokens separated by spaces. The provided example data files are:

  • src-train.txt
  • tgt-train.txt
  • src-val.txt
  • tgt-val.txt
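Given that format, loading a pair of parallel files comes down to splitting each line on whitespace and mapping tokens to integer ids. The sketch below is illustrative only and is not the code in datasets.py or prepare_datasets.py; the function names, the special tokens (<pad>, <unk>, <s>, </s>), and the frequency-ordered vocabulary are all assumptions.

```python
from collections import Counter

# Hypothetical special tokens; the repo may use different ones.
SPECIALS = ("<pad>", "<unk>", "<s>", "</s>")


def read_parallel(source_lines, target_lines):
    """Pair aligned source/target lines and split each one on whitespace."""
    if len(source_lines) != len(target_lines):
        raise ValueError("source and target must have the same number of lines")
    return [(src.split(), tgt.split()) for src, tgt in zip(source_lines, target_lines)]


def build_vocab(sentences, min_freq=1):
    """Map tokens to ids: specials first, then tokens by descending frequency."""
    counts = Counter(tok for sent in sentences for tok in sent)
    vocab = {tok: idx for idx, tok in enumerate(SPECIALS)}
    # Ties are broken alphabetically so the ids are deterministic.
    for tok, freq in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
        if freq >= min_freq and tok not in vocab:
            vocab[tok] = len(vocab)
    return vocab


def encode(tokens, vocab):
    """Convert tokens to ids, mapping unknown tokens to <unk>."""
    unk = vocab["<unk>"]
    return [vocab.get(tok, unk) for tok in tokens]
```

To try it on the example data, read each file with open(path, encoding="utf-8").read().splitlines() and pass the resulting lists to read_parallel.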

Train model

To train a model, pass the training script the path to the processed data and the paths of the files to save, as follows:

$ python train.py --data_dir=data/example/processed --save_config=checkpoints/example_config.json --save_checkpoint=checkpoints/example_model.pth --save_log=logs/example.log 

This saves the model config, checkpoint, and training log to the given files, respectively. You can adjust the model's hyperparameters through command-line arguments; for example, add --epochs=300 to train for 300 epochs.
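A "Noam" optimizer typically wraps an optimizer such as Adam with the learning-rate schedule from the paper, which warms up linearly and then decays with the inverse square root of the step number: lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5). The sketch below implements only that formula and is not the code in optimizers.py; the function name is hypothetical, the defaults d_model=512 and warmup_steps=4000 are the paper's base-model values rather than confirmed defaults of this repo, and factor is an assumed extra scaling knob.

```python
def noam_lr(step, d_model=512, warmup_steps=4000, factor=1.0):
    """Learning rate at a 1-based training step under the paper's schedule."""
    if step < 1:
        raise ValueError("step must be >= 1")
    # Linear warm-up (step * warmup_steps^-1.5) until step == warmup_steps,
    # then inverse-square-root decay (step^-0.5), both scaled by d_model^-0.5.
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
```

In PyTorch, one way to apply it is torch.optim.lr_scheduler.LambdaLR with the optimizer's base learning rate set to 1.0 and lr_lambda=lambda s: noam_lr(s + 1), since LambdaLR starts counting at 0 and multiplies the base rate by the returned factor.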

Translate

To translate a sentence from the source language into the target language:

$ python predict.py --source="There is an imbalance here ." --config=checkpoints/example_config.json --checkpoint=checkpoints/example_model.pth

Candidate 0 : Hier fehlt das Gleichgewicht .
Candidate 1 : Hier fehlt das das Gleichgewicht .
Candidate 2 : Hier fehlt das das das Gleichgewicht .

This prints translation candidates for the given source sentence. The number of candidates can be adjusted with a command-line argument.
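Ranked candidates like these are what beam search (beam.py) produces: at each step it keeps only the highest-scoring partial translations. The following is a generic pure-Python sketch, not the repo's implementation; next_log_probs is a hypothetical stand-in for the decoder that maps a prefix to next-token log-probabilities, and toy_model is a made-up distribution used only to exercise the function.

```python
import math


def beam_search(next_log_probs, bos, eos, beam_size=3, max_len=20):
    """Return up to beam_size (tokens, score) pairs, highest score first.

    next_log_probs(prefix) must return a dict {next_token: log_probability}.
    A hypothesis's score is the sum of its token log-probabilities
    (no length normalization).
    """
    beams = [([bos], 0.0)]  # active (unfinished) hypotheses
    finished = []           # hypotheses that have emitted eos
    for _ in range(max_len):
        # Expand every active hypothesis by every possible next token.
        candidates = []
        for tokens, score in beams:
            for tok, lp in next_log_probs(tokens).items():
                candidates.append((tokens + [tok], score + lp))
        # Keep only the beam_size best expansions overall.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:beam_size]:
            if tokens[-1] == eos:
                finished.append((tokens, score))
            else:
                beams.append((tokens, score))
        # Stop when nothing is left to expand or enough hypotheses finished.
        if not beams or len(finished) >= beam_size:
            break
    # Include any still-unfinished hypotheses, then rank everything.
    finished.extend(beams)
    finished.sort(key=lambda c: c[1], reverse=True)
    return finished[:beam_size]


def toy_model(prefix):
    """Hypothetical next-token distribution over the tokens a, b, </s>."""
    if prefix[-1] == "<s>":
        return {"a": math.log(0.6), "b": math.log(0.4)}
    return {"</s>": math.log(0.7), "a": math.log(0.2), "b": math.log(0.1)}
```

With beam_size=2, beam_search(toy_model, "<s>", "</s>", beam_size=2) returns ["<s>", "a", "</s>"] (probability 0.6 * 0.7 = 0.42) ahead of ["<s>", "b", "</s>"] (0.4 * 0.7 = 0.28). Because scores are unnormalized sums of log-probabilities, longer hypotheses are penalized; real decoders often add length normalization.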

Evaluate

To calculate the BLEU score of a trained model:

$ python evaluate.py --save_result=logs/example_eval.txt --config=checkpoints/example_config.json --checkpoint=checkpoints/example_model.pth

BLEU score : 0.0007947
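BLEU measures n-gram overlap between model translations and reference translations. The sketch below is a plain corpus-level BLEU (clipped 1- to 4-gram precisions combined by a geometric mean, multiplied by a brevity penalty) with one reference per sentence and no smoothing. It is not necessarily how evaluate.py computes its score, which may differ in tokenization, smoothing, or scale; corpus_bleu and ngrams are hypothetical names.

```python
import math
from collections import Counter


def ngrams(tokens, n):
    """Count the n-grams (as tuples) in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def corpus_bleu(hypotheses, references, max_n=4):
    """Corpus BLEU in [0, 1]: one reference per hypothesis, uniform weights, no smoothing.

    hypotheses and references are lists of token lists aligned by index.
    """
    matches = [0] * max_n
    totals = [0] * max_n
    hyp_len = ref_len = 0
    for hyp, ref in zip(hypotheses, references):
        hyp_len += len(hyp)
        ref_len += len(ref)
        for n in range(1, max_n + 1):
            hyp_counts = ngrams(hyp, n)
            ref_counts = ngrams(ref, n)
            # Clip each hypothesis n-gram count by its count in the reference.
            matches[n - 1] += sum(min(c, ref_counts[g]) for g, c in hyp_counts.items())
            totals[n - 1] += max(len(hyp) - n + 1, 0)
    if min(matches) == 0:
        # Some n-gram order has no matches; without smoothing the score is 0.
        return 0.0
    log_precision = sum(math.log(m / t) for m, t in zip(matches, totals)) / max_n
    # Penalize hypotheses that are shorter overall than the references.
    brevity_penalty = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)
    return brevity_penalty * math.exp(log_precision)
```

For example, scoring "Hier fehlt das das Gleichgewicht ." against the reference "Hier fehlt das Gleichgewicht ." gives 0.0, because none of its 4-grams appear in the reference. Without smoothing, any n-gram order with no matches across the whole corpus sends the score to zero.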

File description

  • models.py includes the Transformer's encoder, decoder, and multi-head attention.
  • embeddings.py contains the positional encoding.
  • losses.py contains the label smoothing loss.
  • optimizers.py contains the Noam optimizer.
  • metrics.py contains the accuracy metric.
  • beam.py contains beam search.
  • datasets.py has code for loading and processing data.
  • trainer.py has code for training the model.
  • prepare_datasets.py processes the data.
  • train.py trains a model.
  • predict.py translates a given source sentence with a trained model.
  • evaluate.py calculates the BLEU score of a trained model.
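Two of the components listed above are compact enough to sketch directly. The code below follows the paper's sinusoidal positional encoding, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)), and a common form of label smoothing that keeps 1 - epsilon on the true token and spreads epsilon evenly over the other tokens (epsilon = 0.1 is the paper's value). It is a pure-Python illustration, not the code in embeddings.py or losses.py; the function names are hypothetical, and losses.py may, for instance, exclude the padding token when spreading epsilon.

```python
import math


def positional_encoding(max_len, d_model):
    """Sinusoidal position encodings as a max_len x d_model list of lists."""
    table = []
    for pos in range(max_len):
        row = []
        for j in range(d_model):
            i = j // 2  # dimensions 2i and 2i+1 share one frequency
            angle = pos / 10000 ** (2 * i / d_model)
            row.append(math.sin(angle) if j % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


def smoothed_targets(target, vocab_size, epsilon=0.1):
    """Target distribution: 1 - epsilon on the true token, epsilon shared by the rest."""
    off_value = epsilon / (vocab_size - 1)
    dist = [off_value] * vocab_size
    dist[target] = 1.0 - epsilon
    return dist


def label_smoothing_loss(log_probs, target, epsilon=0.1):
    """Cross-entropy between the smoothed target and the model's log-probabilities."""
    dist = smoothed_targets(target, len(log_probs), epsilon)
    return -sum(p * lp for p, lp in zip(dist, log_probs))
```

With epsilon=0 the loss reduces to ordinary cross-entropy, -log_probs[target]; a positive epsilon discourages the model from becoming overconfident in a single token.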

Reference

Author

@dreamgonfly
