A simple, unofficial implementation of MAE (Masked Autoencoders are Scalable Vision Learners) using pytorch-lightning. A PyTorch implementation by the authors can be found here.
Currently supports training on CUB, Stanford Cars, and STL-10, and is easily extensible to other image datasets.
- Updated for compatibility with PyTorch 2.0 and PyTorch-Lightning 2.0. This probably breaks backwards compatibility, so a release was created for the old version of the code.
- Modified parts of the training code for better conciseness and efficiency.
- Added features, including the option to save some validation reconstructions during training. Known issue: saving reconstructions during distributed training freezes at the end of the validation epoch.
- Retrained CUB and Cars models with new code and a stronger decoder.
- Fixed a bug in the code for generating mask indices. Retrained and updated the reconstruction figures (see below). They aren't quite as pretty now, but they make more sense.
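For readers unfamiliar with what "mask indices" means here, the following is a minimal, standard-library sketch of MAE-style random masking. It is not the repository's code: the function name `random_mask_indices`, the 75% mask ratio, and the 196-patch grid (a 224x224 image split into 16x16 patches) are assumptions chosen for illustration.

```python
import random


def random_mask_indices(num_patches, mask_ratio=0.75, seed=None):
    """Split patch indices into a visible (kept) set and a masked set.

    Hypothetical helper for illustration; the repository's own
    implementation may differ.
    """
    rng = random.Random(seed)
    order = list(range(num_patches))
    rng.shuffle(order)  # random permutation of patch positions
    num_keep = int(num_patches * (1 - mask_ratio))
    keep = sorted(order[:num_keep])    # patches the encoder would see
    masked = sorted(order[num_keep:])  # patches the decoder would reconstruct
    return keep, masked


# Assumed setup: 14 x 14 = 196 patches, 75% of them masked.
keep, masked = random_mask_indices(196, mask_ratio=0.75, seed=0)
print(len(keep), len(masked))
```

The two index lists are disjoint and together cover every patch exactly once, which is the property a mask-index bug would most likely violate.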
# Clone the repository
git clone https://github.com/catalys1/mae-pytorch.git
cd mae-pytorch
# Install required libraries (inside a virtual environment preferably)
pip install -r requirements.txt
# Set up .env for path to data
echo "DATADIR=/path/to/data" > .env
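To illustrate what the `.env` file provides, here is a small sketch that parses `KEY=VALUE` lines and reads `DATADIR` back. The repository uses python-dotenv for this; the `read_env_file` helper below is a simplified standard-library stand-in written only for illustration, and it writes to a temporary directory rather than the repository root.

```python
import tempfile
from pathlib import Path


def read_env_file(path):
    """Parse simple KEY=VALUE lines, skipping blank lines and comments.

    Hypothetical stand-in for python-dotenv, covering only the basic case.
    """
    values = {}
    for line in Path(path).read_text().splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        values[key.strip()] = value.strip().strip('"').strip("'")
    return values


# Write the same line as the echo command above into a temporary .env file.
with tempfile.TemporaryDirectory() as tmp:
    env_path = Path(tmp) / ".env"
    env_path.write_text("DATADIR=/path/to/data\n")
    settings = read_env_file(env_path)

datadir = Path(settings["DATADIR"])  # root directory the datasets are read from
print(datadir)
```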
Training options are provided through configuration files, handled by LightningCLI. See config/ for examples.
Train an MAE model on the CUB dataset:
python train.py fit -c config/mae.yaml -c config/data/cub_mae.yaml
Using multiple GPUs:
python train.py fit -c config/mae.yaml -c config/data/cub_mae.yaml --trainer.devices 8
Evaluate the learned representations using a linear probe. First, pretrain the model on the 100,000 samples of the STL-10 'unlabeled' split.
python train.py fit -c config/mae.yaml -c config/data/stl10_mae.yaml
Next, discard the decoder and append a linear probe to the last layer of the frozen encoder. The classifier is trained on 4,000 labeled samples of the 'train' split (another 1,000 are held out for validation during training) and evaluated on the 'test' split. To do so, provide the path to the pretrained model checkpoint in the command below.
python linear_probe.py -c config/linear_probe.yaml -c config/data/stl10_linear_probe.yaml --model.init_args.ckpt_path <path to pretrained .ckpt>
This yields 77.96% accuracy on the test data.
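The linear-probe setup described above can be sketched as follows. This is a minimal illustration, not the contents of `linear_probe.py`: the `LinearProbe` class is hypothetical, and a tiny stand-in module replaces the pretrained ViT encoder so the example runs without a checkpoint. The 768-dimensional feature size (ViT-Base) and 10 output classes on 96x96 images (STL-10) are assumed here.

```python
import torch
from torch import nn


class LinearProbe(nn.Module):
    """Frozen encoder followed by a single trainable linear layer (hypothetical)."""

    def __init__(self, encoder, feature_dim, num_classes):
        super().__init__()
        self.encoder = encoder
        # Freeze every encoder weight so only the head receives gradients.
        for param in self.encoder.parameters():
            param.requires_grad = False
        # Note: a later call to probe.train() would switch the encoder back
        # to training mode, so eval() may need to be re-applied.
        self.encoder.eval()
        self.head = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        with torch.no_grad():  # no gradient flows into the encoder
            features = self.encoder(x)
        return self.head(features)


# Stand-in encoder (assumption): maps a 3x96x96 image to a 768-d feature.
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 96 * 96, 768))
probe = LinearProbe(encoder, feature_dim=768, num_classes=10)

trainable = [name for name, p in probe.named_parameters() if p.requires_grad]
logits = probe(torch.randn(4, 3, 96, 96))
print(trainable, tuple(logits.shape))
```

Only the head's weight and bias remain trainable, so an optimizer built from the parameters with `requires_grad=True` leaves the pretrained representation untouched.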
Not yet implemented.
The default model uses ViT-Base for the encoder, and a small ViT (depth=6, width=384) for the decoder. This is smaller than the model used in the paper.
- Configuration and training are handled completely by pytorch-lightning.
- The MAE model uses the VisionTransformer from timm.
- Interface to FGVC datasets through fgvcdata.
- Configurable environment variables through python-dotenv.
Image reconstructions of CUB validation set images after training with the following command:
python train.py fit -c config/mae.yaml -c config/data/cub_mae.yaml --data.init_args.batch_size 256 --data.init_args.num_workers 12
Image reconstructions of Cars validation set images after training with the following command:
python train.py fit -c config/mae.yaml -c config/data/cars_mae.yaml --data.init_args.batch_size 256 --data.init_args.num_workers 16
Param | Setting |
---|---|
GPUs | 1xA100 |
Batch size | 256 |
Learning rate | 1.5e-4 |
LR schedule | Cosine decay |
Warmup | 10% of steps |
Training steps | 78,000 |
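
The learning-rate settings in the table can be sketched as a function of the training step. The table gives only the peak rate, the cosine decay, the 10% warmup, and the step count; the linear shape of the warmup and the decay down to zero are assumptions, and `lr_at` is a hypothetical helper rather than the scheduler used in the repository.

```python
import math


def lr_at(step, total_steps=78_000, base_lr=1.5e-4, warmup_frac=0.1):
    """Learning rate at a given step: linear warmup, then cosine decay.

    Assumes warmup starts at 0 and the cosine decays to 0.
    """
    warmup_steps = int(total_steps * warmup_frac)
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))


# Sample the schedule at the start, end of warmup, mid-decay, and final step.
for step in (0, 7_800, 42_900, 78_000):
    print(step, lr_at(step))
```

Under these assumptions the rate peaks at 1.5e-4 after 7,800 steps and falls to half of that at step 42,900, the midpoint of the decay phase.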
Training and validation loss curves for CUB.
Validation image reconstructions over the course of training.