A PyTorch implementation of A Neural Probabilistic Language Model (Bengio et al., 2003). The training and data-loading code is based on the PyTorch example Word level language model.
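The model predicts the next word from the n-1 previous words. It looks up their embeddings, concatenates them, passes the result through a tanh hidden layer and normalizes the output scores with a softmax. The sketch below shows that forward pass in plain Python with toy random weights. It assumes a single hidden layer and omits the paper's optional direct input-to-output connections, so it does not necessarily match how this repository builds its model. The names `nplm_probs` and `rand_matrix` are hypothetical.

```python
import math
import random

def nplm_probs(context_ids, emb, H, d, U, b):
    """Return P(w | context) for every word w, given n-1 context word ids.

    emb: V x m embedding table; H: h x (n-1)*m hidden weights; d: hidden bias;
    U: V x h output weights; b: output bias (all plain lists of floats).
    """
    # 1. Look up the context words and concatenate their embeddings into x.
    x = [v for w in context_ids for v in emb[w]]
    # 2. Hidden layer: a = tanh(d + H x).
    a = [math.tanh(d_j + sum(h_jk * x_k for h_jk, x_k in zip(row, x)))
         for row, d_j in zip(H, d)]
    # 3. Scores y = b + U a (the paper's optional direct term W x is omitted).
    y = [b_i + sum(u_ij * a_j for u_ij, a_j in zip(row, a))
         for row, b_i in zip(U, b)]
    # 4. Softmax over the vocabulary, shifted by max(y) for numerical stability.
    top = max(y)
    exps = [math.exp(s - top) for s in y]
    total = sum(exps)
    return [e / total for e in exps]

def rand_matrix(rows, cols, rng):
    """Hypothetical helper: a rows x cols matrix of small random weights."""
    return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]

# Toy sizes: vocabulary V=10, embedding m=4, hidden size 8, order 5 (4 context words).
rng = random.Random(0)
V, m, hidden, order = 10, 4, 8, 5
emb = rand_matrix(V, m, rng)
H, d = rand_matrix(hidden, (order - 1) * m, rng), [0.0] * hidden
U, b = rand_matrix(V, hidden, rng), [0.0] * V
p = nplm_probs([1, 2, 3, 4], emb, H, d, U, b)
```

The result `p` is one probability per vocabulary word; training would adjust `emb`, `H`, `d`, `U` and `b` to raise the probability of the observed next word.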
To get the wikitext-2 dataset, run:
./get-data.sh
A word-level example:
./main.py train --name wiki --order 5 --batch-size 32
A character-level example:
./main.py train --name wiki-char --use-char --order 12 --emb-dim 20 --batch-size 1024
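Both examples train on fixed-size windows. A plausible reading, assumed here rather than confirmed, is that `--order` is the n-gram order, so each token is predicted from the `order - 1` tokens before it, and that with `--use-char` the tokens are characters instead of words. A minimal sketch of that windowing follows; the function `ngram_examples` and its `pad` argument are hypothetical, and padding the start with `<sos>` is only suggested by the `--no-sos` generation flag.

```python
def ngram_examples(tokens, order, pad=None):
    """Yield (context, target) pairs where context holds the order-1 preceding tokens.

    If pad is given (e.g. "<sos>"), order-1 copies are prepended so that the
    first real token also gets a full-length context (an assumption).
    """
    if pad is not None:
        tokens = [pad] * (order - 1) + list(tokens)
    for i in range(order - 1, len(tokens)):
        yield tuple(tokens[i - order + 1:i]), tokens[i]

words = "the cat sat on the mat".split()   # word level: whitespace-split tokens
chars = list("the cat")                    # character level: every character is a token

word_pairs = list(ngram_examples(words, 3))
char_pairs = list(ngram_examples(chars, 4))
```

Here `word_pairs[0]` is `(("the", "cat"), "sat")`. Character-level models see far fewer distinct tokens, which is presumably why the character example uses a longer `--order` and a much smaller `--emb-dim`.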
If you have pretrained GloVe vectors, you can use those:
./main.py train --name wiki --use-glove --glove-dir your/glove/dir --emb-dim 50
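GloVe vectors are distributed as text files with one word per line followed by its space-separated components (for example `glove.6B.50d.txt` from the Stanford release matches `--emb-dim 50`; which file `--glove-dir` expects is not stated here). A minimal sketch of copying such vectors into an embedding table, assuming words missing from GloVe keep a small random initialization; `load_glove` is a hypothetical name, not this repository's API.

```python
import random

def load_glove(lines, vocab, dim, seed=0):
    """Build a len(vocab) x dim table from GloVe-format lines.

    lines: iterable of strings such as "the 0.418 0.24968 ..."
    vocab: dict mapping word -> row index
    Returns (table, number of vocabulary words found in GloVe).
    """
    rng = random.Random(seed)
    # Start from small random vectors so unknown words still get an embedding.
    table = [[rng.uniform(-0.1, 0.1) for _ in range(dim)] for _ in range(len(vocab))]
    found = 0
    for line in lines:
        parts = line.rstrip().split(" ")
        word, values = parts[0], parts[1:]
        # Skip words outside the vocabulary and lines of the wrong dimension.
        if word in vocab and len(values) == dim:
            table[vocab[word]] = [float(v) for v in values]
            found += 1
    return table, found

vocab = {"the": 0, "cat": 1, "<unk>": 2}
lines = ["the 0.1 0.2", "dog 0.3 0.4", "cat -0.5 0.6"]
table, found = load_glove(lines, vocab, dim=2)
```

In practice `lines` would be an open file handle; here two of the three vocabulary words are found and `<unk>` keeps its random vector.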
Some other data arguments are:
--lower # Lowercase all words in training data.
--no-headers # Remove all headers such as `=== History ===`.
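A minimal sketch of what these two flags could do to the raw lines, assuming a header is any line consisting of a title wrapped in `=` markers (as in `=== History ===`, or the spaced ` = = Gameplay = = ` form found in tokenized WikiText). The function `preprocess` and the `HEADER` pattern are hypothetical and may differ from the repository's actual rules.

```python
import re

# Assumed header shape: optional spaces, one or more "=" markers, a title,
# one or more "=" markers, optional trailing whitespace.
HEADER = re.compile(r"^\s*(=\s*)+[^=]+?(\s*=)+\s*$")

def preprocess(lines, lower=False, no_headers=False):
    """Apply --lower / --no-headers style cleaning to a list of text lines."""
    out = []
    for line in lines:
        if no_headers and HEADER.match(line):
            continue  # drop the whole header line
        out.append(line.lower() if lower else line)
    return out

lines = [" = Valkyria Chronicles III = ", "=== History ===",
         "The Game was released .", " = = Gameplay = = "]
cleaned = preprocess(lines, lower=True, no_headers=True)
```

With both flags set, only the body line survives, lowercased; with neither flag the lines are returned unchanged.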
With the following arguments, one epoch takes around 45 minutes:
./main.py train --name wiki --order 5 --use-glove --emb-dim 50 --hidden-dims 100 \
--batch-size 128 --epochs 10 # Test perplexity 224.89
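Test perplexity is conventionally the exponential of the mean negative log-likelihood of the test tokens, which is how the PyTorch word-level example reports it (`exp` of the average cross-entropy loss); this repository is assumed to follow that convention. A minimal sketch from per-token probabilities of the correct next token, with `perplexity` as a hypothetical helper:

```python
import math

def perplexity(target_probs):
    """exp(mean negative log-likelihood) of the probabilities given to the true tokens."""
    nll = -sum(math.log(p) for p in target_probs) / len(target_probs)
    return math.exp(nll)

uniform = perplexity([0.25] * 10)   # a uniform guess among 4 options
mixed = perplexity([0.5, 0.125])    # geometric mean of the probabilities is 1/4
```

Both examples give 4: perplexity is the inverse geometric mean of the assigned probabilities, so a test perplexity of 224.89 means the model is, on average, about as uncertain as a uniform choice among roughly 225 words.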
Longer contexts, larger embeddings and hidden layers, and heavier dropout lower the test perplexity further:
./main.py train --name wiki --order 13 --emb-dim 100 --hidden-dims 500 \
--epochs 40 --batch-size 512 --dropout 0.5 # Test perplexity 153.12
./main.py train --name wiki --order 13 --emb-dim 300 --hidden-dims 1400 \
--epochs 40 --batch-size 256 --dropout 0.65 # Test perplexity 152.64
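A rough parameter count shows where the cost of these larger configurations goes. The sketch assumes one hidden layer of size `--hidden-dims`, no direct input-to-output connections and a full softmax over the vocabulary; these are assumptions, not a description of this repository's model, and `nplm_params` is a hypothetical helper.

```python
def nplm_params(vocab, order, emb_dim, hidden_dim):
    """Count weights and biases of a one-hidden-layer NPLM (assumed architecture)."""
    embedding = vocab * emb_dim                                  # lookup table
    hidden = (order - 1) * emb_dim * hidden_dim + hidden_dim     # H and d
    output = hidden_dim * vocab + vocab                          # U and b
    return {"embedding": embedding, "hidden": hidden, "output": output}

# 33,278 is the vocabulary size usually reported for WikiText-2; this
# repository's vocabulary may differ (e.g. with --lower or extra special tokens).
big = nplm_params(vocab=33278, order=13, emb_dim=300, hidden_dim=1400)
```

For the last configuration this gives about 9.98M embedding, 5.04M hidden and 46.6M output parameters, so the output softmax dominates, which is what a softmax approximation would target.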
To generate text, use:
./main.py generate --checkpoint path/to/saved/model
The `<eos>` token is replaced with a newline, and all other tokens are printed as is.
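A minimal sketch of that rendering step, assuming word tokens are joined with spaces and character tokens with no separator. The function `render` and its `sep` and `show_sos` parameters are hypothetical; `show_sos=False` mimics the `--no-sos` flag.

```python
def render(tokens, sep=" ", show_sos=True):
    """Turn generated tokens into text: each <eos> ends a line."""
    lines, current = [], []
    for tok in tokens:
        if tok == "<eos>":
            lines.append(sep.join(current))  # close the current line
            current = []
        elif tok == "<sos>" and not show_sos:
            continue  # hide start-of-sequence markers
        else:
            current.append(tok)  # everything else is kept as is
    lines.append(sep.join(current))
    return "\n".join(lines)

words = render(["<sos>", "the", "cat", "<eos>", "a", "dog"], show_sos=False)
chars = render(list("hi") + ["<eos>"] + list("yo"), sep="")
```

Here `words` is `"the cat\na dog"` and `chars` is `"hi\nyo"`.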
Other generation arguments are:
--temperature 0.9 # Sampling temperature: values below 1 sharpen the distribution, values above 1 flatten it.
--start # Optional start of the generated text (may be longer than --order).
--no-unk # Do not generate <unk> tokens; especially useful at low --temperature.
--no-sos # Do not print <sos> tokens.
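Temperature sampling divides the model's scores by the temperature before the softmax, so low temperatures concentrate probability on the likeliest tokens. That is also why `<unk>` tends to dominate at low temperature unless it is masked. The sketch below shows this together with a sliding context window, which explains how `--start` can be longer than the order: only the last `order - 1` tokens are fed to the model. The functions `sample_next` and `generate`, the `model(context) -> scores` callable and the `<sos>` padding are hypothetical stand-ins, not this repository's code.

```python
import math
import random

def sample_next(logits, temperature=1.0, unk_id=None, rng=random):
    """Sample a token index from softmax(logits / temperature)."""
    scaled = [x / temperature for x in logits]
    if unk_id is not None:
        scaled[unk_id] = float("-inf")  # --no-unk: <unk> gets probability 0
    top = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(s - top) for s in scaled]
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]

def generate(model, start_ids, order, steps, sos_id, **sample_kw):
    """Extend start_ids by `steps` sampled tokens using a fixed-size context."""
    # Assumption: pad with <sos> so even an empty start yields a full context.
    ids = [sos_id] * (order - 1) + list(start_ids)
    for _ in range(steps):
        context = ids[-(order - 1):]  # a longer start is simply truncated here
        ids.append(sample_next(model(context), **sample_kw))
    return ids

rng = random.Random(0)
greedy_like = sample_next([1.0, 3.0, 2.0], temperature=0.01, rng=rng)
```

At `temperature=0.01` the scores are multiplied by 100, so `greedy_like` is the index of the largest score, 1; at temperature 1 the model's own distribution is sampled unchanged.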
See some generated text in generate.txt.
To visualize the trained embeddings of the model, use:
./main.py plot --checkpoint path/to/saved/model
This fits a 2D t-SNE projection of the embeddings of the 1000 most common words in the dataset and colors the points by K-means cluster. It requires Bokeh for plotting and scikit-learn for t-SNE and K-means.
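The first step, choosing which embedding rows to plot, can be sketched with `collections.Counter`. The helper `most_common_embeddings` is hypothetical, and whether the repository filters special tokens such as `<unk>` or `<eos>` is not stated.

```python
from collections import Counter

def most_common_embeddings(tokens, vocab, embeddings, k=1000):
    """Return (words, vectors) for the k most frequent in-vocabulary tokens."""
    counts = Counter(t for t in tokens if t in vocab)
    words = [w for w, _ in counts.most_common(k)]
    return words, [embeddings[vocab[w]] for w in words]

tokens = "a b a c a b".split()
vocab = {"a": 0, "b": 1, "c": 2}
emb = [[0.0], [1.0], [2.0]]
words, vectors = most_common_embeddings(tokens, vocab, emb, k=2)
```

The selected `vectors` could then be passed to scikit-learn's `sklearn.manifold.TSNE(n_components=2).fit_transform` for the 2D coordinates and to `sklearn.cluster.KMeans(...).fit_predict` for the cluster colors; whether the clustering runs on the full embeddings or on the 2D points is an open choice not specified here.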
See an example HTML plot here. (GitHub does not render HTML files; to view it, download and open it, or use this link.)
Requirements:
python>=3.6
torch==0.3.0.post4
numpy
tqdm
- Convert to torch4
- Text generation by sampling.
- Plot embeddings with t-SNE
- Perplexity for user input.
- Softmax approximation.