This is a PyTorch Lightning implementation of the Neural Turing Machine (NTM). For more details on the NTM, please see the paper.
PyTorch Lightning is a lightweight PyTorch wrapper. It organises your code neatly, abstracts away the complicated and error-prone engineering, and is device agnostic, while still giving you all the flexibility of a standard PyTorch training procedure.
For more information on PyTorch Lightning, see the documentation.
This repository is a PyTorch Lightning conversion of this PyTorch NTM implementation. We extend the available implementation with an LSTM network as a baseline comparison. The repository can be divided into three main parts:
- `run_train.py` is the Lightning trainer which runs the training loop and logs the outputs.
- `data_copytask.py` is the Lightning dataset for the copy task from the original paper. We do not implement the copy-repeat task, but this could be done similarly to the original PyTorch repository.
- `model.py` is the Lightning model which specifies the training and validation loops (a minimal sketch of such a module follows this list). Within this model we call the different networks, which are:
  - `model_ntm.py`, the NTM implementation. The remaining files are in the folder `ntm/*`; these are a copy of the files from the original repository, and credits go to its authors.
  - `model_lstm.py`, the LSTM baseline implementation.
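The code below is a minimal sketch of what such a LightningModule can look like; the class name, the choice of loss, and the RMSprop optimizer are illustrative assumptions, and the actual `model.py` may differ.

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class SequenceCopyModule(pl.LightningModule):
    """Illustrative LightningModule wrapping either the NTM or the LSTM baseline."""

    def __init__(self, net):
        super().__init__()
        self.net = net  # hypothetical: an instance built from model_ntm.py or model_lstm.py

    def training_step(self, batch, batch_idx):
        inputs, targets = batch            # binary input sequences and the expected copies
        outputs = self.net(inputs)         # assumed to end in a sigmoid, so values lie in [0, 1]
        loss = F.binary_cross_entropy(outputs, targets)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, targets = batch
        outputs = self.net(inputs)
        self.log("val_loss", F.binary_cross_entropy(outputs, targets))

    def configure_optimizers(self):
        # RMSprop is what the NTM literature typically uses; the learning rate here is an assumption.
        return torch.optim.RMSprop(self.parameters(), lr=1e-4)
```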
Note that training and validation sequences are generated on the fly, so every epoch sees different data.
Set up the environment
pip install -r requirements.txt
To run a model, call
python run_train.py --model MODELNAME
where MODELNAME is either ntm or lstm.
You can add any number of Lightning-specific options, e.g.
python run_train.py --model ntm --gpus 1 --fast_dev_run True
runs the NTM model on a single GPU, but only performs a single fast development run to check that all parts of the code work.
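These trainer flags pass through because `run_train.py` hands them to the Lightning Trainer. The snippet below is a rough sketch of how this wiring typically looks in Lightning 1.x (using `Trainer.add_argparse_args` / `from_argparse_args`, which were removed in Lightning 2.x); the actual script may be organised differently.

```python
from argparse import ArgumentParser

import pytorch_lightning as pl


def main():
    parser = ArgumentParser()
    parser.add_argument("--model", type=str, default="ntm", choices=["ntm", "lstm"])
    # Adds all Trainer flags such as --gpus and --fast_dev_run (Lightning 1.x API).
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    trainer = pl.Trainer.from_argparse_args(args)
    # Model and datamodule construction omitted; it depends on the chosen --model.
    # trainer.fit(model, datamodule)


if __name__ == "__main__":
    main()
```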
In this part we present some results obtained for the copy task. The goal of the copy task is to test the network's ability to store and recall arbitrarily long sequences. The input is a sequence of random length (between 1 and 20) of binary vectors with a given number of bits (8), followed by a delimiter bit. For example, we may obtain an input sequence of 20 by 8 which the network has to store and reproduce at the output.
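The sketch below illustrates how such a sample can be drawn (the function name and tensor layout are assumptions; the actual `data_copytask.py` may arrange things differently). Because a fresh sample is drawn on every call, each epoch naturally sees different sequences.

```python
import torch


def make_copy_sample(max_len=20, bits=8):
    """Draw one copy-task example: a random-length binary sequence plus a delimiter."""
    seq_len = int(torch.randint(1, max_len + 1, (1,)))
    seq = torch.randint(0, 2, (seq_len, bits)).float()

    # The input has one extra channel for the delimiter, which fires once after the sequence.
    inp = torch.zeros(seq_len + 1, bits + 1)
    inp[:seq_len, :bits] = seq
    inp[seq_len, bits] = 1.0  # delimiter bit

    # The target is simply the stored sequence, to be reproduced after the delimiter.
    return inp, seq
```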
We run both networks over 10 seeds using the bash script `multiple_train.sh`; see the options within the script for the exact training settings used in our scenario. Note that we use a batch size of 8 to speed up training, compared to a batch size of 1 in the original paper. We show mean and standard deviation for the training and validation data.
The individual validation costs are shown in the following figures; the top one is for the NTM and the bottom one for the LSTM.