This repository contains code to accompany the paper "Neural networks trained with SGD learn distributions of increasing complexity" [arXiv:2211.11567] by M. Refinetti, A. Ingrosso, and S. Goldt.
In this plot, we show the test accuracy of a ResNet18 evaluated on CIFAR10 during training with SGD on four different training data sets: the standard CIFAR10 training set (dark blue), and three different "clones" of the training set. The images of the clones were drawn from a Gaussian mixture fitted to CIFAR10, from a Wasserstein GAN (WGAN) trained on CIFAR10, and from the cifar5m data set of Nakkiran et al. The clones form a hierarchy of increasingly accurate approximations to CIFAR10: while the Gaussian mixture captures only the first two moments of the inputs of each class, the WGAN and cifar5m data sets yield increasingly realistic images by capturing higher-order statistics. The ResNet18 trained on the Gaussian mixture has the same test accuracy on CIFAR10 as the baseline model, trained directly on CIFAR10, for the first 50 steps of SGD; the ResNet18 trained on cifar5m matches the baseline for about 2000 steps. This result suggests that the network trained on CIFAR10 discriminates the images using increasingly higher-order statistics during training.
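For concreteness, the Gaussian-mixture clone can be sketched in a few lines of NumPy. This is a hypothetical illustration, not code from the repository: `fit_gaussian_clone` and `sample_clone` are made-up helper names, and real usage would flatten the CIFAR10 images into vectors first. Per class, we estimate the mean and covariance and sample new inputs from that Gaussian, so the clone matches the first two moments of each class by construction.

```python
import numpy as np

def fit_gaussian_clone(X, y, num_classes):
    """Estimate a per-class mean and covariance from data X (N, D) and labels y."""
    params = {}
    for c in range(num_classes):
        Xc = X[y == c]
        mu = Xc.mean(axis=0)
        # a small ridge on the diagonal keeps the covariance well-conditioned
        Sigma = np.cov(Xc, rowvar=False) + 1e-4 * np.eye(X.shape[1])
        params[c] = (mu, Sigma)
    return params

def sample_clone(params, y, rng):
    """Draw one clone input per label in y from the fitted class Gaussians."""
    D = params[0][0].shape[0]
    X = np.empty((len(y), D))
    for i, c in enumerate(y):
        mu, Sigma = params[c]
        X[i] = rng.multivariate_normal(mu, Sigma)
    return X
```

The clone inputs keep the original labels; only the input distribution is replaced by its second-order approximation.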
The key program to train the network on distributions of increasing complexity is `dist_inc_comp.py`. Running the program with the `--help` option yields an overview of the available options.
To train a ResNet18 on CIFAR10, simply run

```shell
python dist_inc_comp.py --model resnet18 --dataset cifar10
```
If instead you would like to train the ResNet18 on a Gaussian mixture and test it on CIFAR10 (that's the green line in the plot above), call

```shell
python dist_inc_comp.py --model resnet18 --dataset cifar10 --clone gp
```
where `gp` indicates that the clone used for training is the Gaussian mixture. If you would like to train the model on the WGAN data set or on cifar5m, please contact Sebastian directly while we figure out a better way to share the raw data sets.
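The train-on-clone, test-on-real protocol can be illustrated on toy data. The sketch below is a simplified stand-in for the repo's experiment, not the actual implementation: a logistic regression trained by plain SGD on synthetic two-class data replaces the ResNet18 on CIFAR10, and all names are made up. It builds a Gaussian clone of the training set, trains on it, and always evaluates on real held-out data.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_real(n):
    """Synthetic 'real' data: two classes with different means."""
    y = rng.integers(0, 2, size=n)
    X = rng.normal(size=(n, 10)) + 2.0 * y[:, None]
    return X, y

X_train, y_train = make_real(1000)
X_test, y_test = make_real(1000)

# Gaussian clone: resample each class from its fitted mean and covariance,
# keeping the original labels.
X_clone = np.empty_like(X_train)
for c in (0, 1):
    idx = np.where(y_train == c)[0]
    mu = X_train[idx].mean(axis=0)
    Sigma = np.cov(X_train[idx], rowvar=False)
    X_clone[idx] = rng.multivariate_normal(mu, Sigma, size=len(idx))

def sgd_accuracy(X, y, X_te, y_te, steps=500, lr=0.05):
    """Train a logistic regression with single-sample SGD; return test accuracy."""
    w = np.zeros(X.shape[1])
    b = 0.0
    for _ in range(steps):
        i = rng.integers(len(X))
        p = 1.0 / (1.0 + np.exp(-(X[i] @ w + b)))
        g = p - y[i]  # gradient of the cross-entropy loss w.r.t. the logit
        w -= lr * g * X[i]
        b -= lr * g
    pred = (X_te @ w + b) > 0
    return (pred == y_te).mean()

acc_real = sgd_accuracy(X_train, y_train, X_test, y_test)
acc_clone = sgd_accuracy(X_clone, y_train, X_test, y_test)
print(f"trained on real data:  {acc_real:.2f}")
print(f"trained on clone data: {acc_clone:.2f}")
```

On data this simple the clone-trained model matches the real-trained one closely, since a Gaussian clone already captures everything a linear classifier can use; the paper's point is that deeper networks eventually exploit statistics beyond the second moment, where the clones diverge from the baseline.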
To run the code, you will need up-to-date versions of:
- PyTorch
- numpy
- scipy
- einops