
Learning distributions of increasing complexity

This repository contains code to accompany the paper "Neural networks trained with SGD learn distributions of increasing complexity" [arXiv:2211.11567] by M. Refinetti, A. Ingrosso, and S. Goldt.

In a nutshell

[Figure: Learning distributions of increasing complexity]

In this plot, we show the test accuracy of a ResNet18 evaluated on CIFAR10 during training with SGD on four different training data sets: the standard CIFAR10 training set (dark blue), and three different "clones" of the training set. The images of the clones were drawn from a Gaussian mixture fitted to CIFAR10, a mixture of Wasserstein GANs (WGANs) fitted to CIFAR10, and the cifar5m data set of Nakkiran et al. The clones form a hierarchy of approximations to CIFAR10: while the Gaussian mixture correctly captures only the first two moments of the inputs of each class, the WGAN and cifar5m data sets yield increasingly realistic images by also capturing higher-order statistics. The ResNet18 trained on the Gaussian mixture has the same test accuracy on CIFAR10 as the baseline model, trained directly on CIFAR10, for the first 50 steps of SGD; the ResNet18 trained on cifar5m matches the baseline model for about 2000 steps. This result suggests that, as training progresses, the network trained on CIFAR10 discriminates the images using statistics of increasingly higher order.
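The idea behind the Gaussian clone can be sketched in a few lines of numpy. This is a minimal illustration under our own assumptions, not code taken from the repository: the function names are hypothetical, inputs are flattened to vectors, and each class gets a single Gaussian with its empirical mean and covariance. The toy Laplace data stand in for images; the clone reproduces their first two moments but not their excess kurtosis, a fourth-order statistic.

```python
import numpy as np


def fit_gaussian_clone(x, y):
    """Fit one Gaussian per class: the empirical mean and covariance of the flattened inputs."""
    params = {}
    for label in np.unique(y):
        xc = x[y == label]
        xc = xc.reshape(len(xc), -1)  # e.g. (n, 3, 32, 32) CIFAR10 images -> (n, 3072)
        # rowvar=False: rows are samples, columns are input dimensions
        params[label] = (xc.mean(axis=0), np.cov(xc, rowvar=False))
    return params


def sample_gaussian_clone(params, n_per_class, rng):
    """Draw n_per_class inputs from each class-conditional Gaussian, with their labels."""
    xs, ys = [], []
    for label, (mean, cov) in params.items():
        xs.append(rng.multivariate_normal(mean, cov, size=n_per_class))
        ys.append(np.full(n_per_class, label))
    return np.concatenate(xs), np.concatenate(ys)


def excess_kurtosis(z):
    """Fourth standardised moment minus 3, per input dimension (0 for a Gaussian)."""
    z = z - z.mean(axis=0)
    return np.mean(z**4, axis=0) / np.mean(z**2, axis=0) ** 2 - 3.0


if __name__ == "__main__":
    # Toy non-Gaussian "data set": two classes of Laplace-distributed inputs in 3 dimensions.
    rng = np.random.default_rng(0)
    n, d = 20000, 3
    x = np.concatenate([rng.laplace(0.0, 1.0, size=(n, d)), rng.laplace(2.0, 1.0, size=(n, d))])
    y = np.concatenate([np.zeros(n, dtype=int), np.ones(n, dtype=int)])

    xg, yg = sample_gaussian_clone(fit_gaussian_clone(x, y), n, rng)
    for label in (0, 1):
        # Means and covariances agree; the fourth-order statistics do not.
        print(label, excess_kurtosis(x[y == label]), excess_kurtosis(xg[yg == label]))
```

For CIFAR10, x would hold the training images flattened to 3072-dimensional vectors, and the sampled clone images could be reshaped back to (3, 32, 32) before training.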

Usage

The main programme to train the network on distributions of increasing complexity is dist_inc_comp.py. Running the programme with the --help option gives an overview of the options.

To train a ResNet18 on CIFAR10, simply run

python dist_inc_comp.py --model resnet18 --dataset cifar10

If instead you would like to train the ResNet18 on a Gaussian mixture and test it on CIFAR10 (that's the green line in the plot above), call

python dist_inc_comp.py --model resnet18 --dataset cifar10 --clone gp

where gp indicates that the clone used for training is the Gaussian mixture. If you would like to train the model on the GAN data set or on cifar5m, please contact Sebastian Goldt directly while we figure out a better way to share the raw data sets.
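If you want to run the baseline and the Gaussian-clone experiment one after the other, a small driver script along these lines may help. It is a hypothetical helper, not part of the repository; it uses only the three options shown above (--model, --dataset, --clone) and assumes it is run from the directory that contains dist_inc_comp.py.

```python
import subprocess
import sys

# The two runs described above: the baseline and the Gaussian clone ("gp").
CONFIGS = [
    {"model": "resnet18", "dataset": "cifar10"},
    {"model": "resnet18", "dataset": "cifar10", "clone": "gp"},
]


def build_command(model, dataset, clone=None):
    """Assemble a dist_inc_comp.py command line from the documented options."""
    cmd = [sys.executable, "dist_inc_comp.py", "--model", model, "--dataset", dataset]
    if clone is not None:
        # Train on the clone instead of the real training set; testing is still on the real data set.
        cmd += ["--clone", clone]
    return cmd


def run_experiments(configs):
    """Run each configuration in turn; check=True raises CalledProcessError if a run fails."""
    for cfg in configs:
        subprocess.run(build_command(**cfg), check=True)
```

Calling run_experiments(CONFIGS) from the repository root then launches the two trainings sequentially, using the same Python interpreter that runs the script.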

Requirements

To run the code, you will need up-to-date versions of

  • PyTorch
  • numpy
  • scipy
  • einops
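To check quickly whether your environment provides these packages, you can look them up without importing them. This is a minimal sketch using only the standard library; it checks that each package is present, not its version, since no minimum versions are specified here.

```python
import importlib.util

# Import names of the required packages (PyTorch is imported as "torch").
REQUIRED = ["torch", "numpy", "scipy", "einops"]


def missing_packages(names=REQUIRED):
    """Return the packages that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("All requirements found.")
```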