- Install PyTorch
pip install pytest submitit hydra-core hydra-submitit-launcher loguru tqdm gitpython transformers lightning matplotlib datasets sortedcontainers maze-dataset pymongo numpy
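A quick way to confirm the environment is complete is to query the installed distribution metadata. The following is a minimal sketch, not part of the repository: the helper names are hypothetical, and `torch` is assumed to be the distribution name of your PyTorch install.

```python
from importlib import metadata

# Distribution names copied from the pip command above.
# "torch" is an assumption about how PyTorch was installed.
REQUIRED = [
    "torch", "pytest", "submitit", "hydra-core", "hydra-submitit-launcher",
    "loguru", "tqdm", "gitpython", "transformers", "lightning", "matplotlib",
    "datasets", "sortedcontainers", "maze-dataset", "pymongo", "numpy",
]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def report(names=REQUIRED):
    """Print one line per distribution with its version or MISSING."""
    for name, version in installed_versions(names).items():
        print(f"{name:28s} {version or 'MISSING'}")


if __name__ == "__main__":
    report()
```

Any line marked MISSING points at a package that still needs to be installed.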
If you want to run A* mazes (from https://github.com/facebookresearch/searchformer/):
- Install MongoDB
- Download maze.gz and maze.vocabulary.gz from https://github.com/facebookresearch/searchformer/blob/main/doc/mongodb.md
- Add both archives to your MongoDB instance:
mongorestore --gzip --archive=maze.gz
mongorestore --gzip --archive=maze.vocabulary.gz
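After restoring, it can help to check what actually landed in MongoDB. The sketch below lists every database and its collections using pymongo. It assumes a MongoDB server on the default local address `mongodb://localhost:27017`. The function names are hypothetical, and the database and collection names created by the archives are not stated here, so the code only prints whatever it finds.

```python
def list_collections(client):
    """Return {database name: sorted collection names} for a MongoClient-like object."""
    summary = {}
    for db_name in client.list_database_names():
        summary[db_name] = sorted(client[db_name].list_collection_names())
    return summary


def main(uri="mongodb://localhost:27017"):
    """Print the databases and collections visible at the given URI (assumed default)."""
    # Imported here so list_collections can be used without pymongo installed.
    from pymongo import MongoClient

    # Fail after 2 seconds instead of waiting for the default timeout.
    client = MongoClient(uri, serverSelectionTimeoutMS=2000)
    for db_name, collections in list_collections(client).items():
        print(db_name, collections)


# Call main() from a Python session once the MongoDB server is running.
```

If the restore succeeded, the output should include databases beyond MongoDB's built-in `admin`, `config`, and `local`.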
Adjust locations: search for "TODO" to find the places that need editing:
- main.py --> code snapshot dir
- train_defaults.yaml --> logs dir
- train_defaults.yaml --> data dir
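The TODO markers listed above can be located with any text search. As one option, here is a small sketch that walks a checkout and prints each matching line. The function name and the choice to scan only `.py` and `.yaml` files are assumptions made for illustration.

```python
from pathlib import Path


def find_todos(root=".", suffixes=(".py", ".yaml"), marker="TODO"):
    """Return (path, line number, stripped line) for each line containing marker."""
    hits = []
    for path in sorted(Path(root).rglob("*")):
        # Skip directories and files with other extensions.
        if not path.is_file() or path.suffix not in suffixes:
            continue
        text = path.read_text(encoding="utf-8", errors="ignore")
        for lineno, line in enumerate(text.splitlines(), start=1):
            if marker in line:
                hits.append((str(path), lineno, line.strip()))
    return hits


if __name__ == "__main__":
    for path, lineno, line in find_todos():
        print(f"{path}:{lineno}: {line}")
```

Run it from the repository root. If the script itself is saved inside the repository, it will also report its own `marker="TODO"` line.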
Locally
python main.py -m mode=local model=gpt dataset=maze datamodule.grid_n=4
Append use_wandb=False or use_wandb=True to the command to disable or enable wandb (for example, disable it while debugging).
python main.py -m mode=local model=past dataset=maze datamodule.grid_n=4
PAST is an encoder-decoder model that runs best with MLM-U (model.train_mode=absorbing). GPT is the best model for autoregressive (AR) training, i.e. left-to-right next-token prediction.
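To make the model and training-mode pairing concrete, the sketch below composes a command line from a dictionary of overrides. `build_command` is a hypothetical helper, not part of the repository. It also assumes that `model.train_mode=absorbing` can be passed on the command line in the same `key=value` form as the other settings.

```python
import shlex


def build_command(model, overrides=None, dataset="maze", mode="local"):
    """Compose a `python main.py -m ...` command line from key=value overrides."""
    args = ["python", "main.py", "-m", f"mode={mode}", f"model={model}", f"dataset={dataset}"]
    # Extra settings are appended in insertion order as key=value pairs.
    for key, value in (overrides or {}).items():
        args.append(f"{key}={value}")
    return shlex.join(args)


# PAST with the absorbing train mode (assumed to be a valid command-line override).
past_cmd = build_command("past", {"datamodule.grid_n": 4, "model.train_mode": "absorbing"})

# GPT with the settings from the local example above.
gpt_cmd = build_command("gpt", {"datamodule.grid_n": 4})

if __name__ == "__main__":
    print(past_cmd)
    print(gpt_cmd)
```

The printed strings can be pasted into a shell, or split with `shlex.split` and passed to `subprocess.run`.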
See the CONTRIBUTING file for how to help out.
This project is Apache 2.0 licensed, as found in the LICENSE file.
The stargraph dataset has been adapted from https://github.com/gregorbachmann/Next-Token-Failures/