Last update: July 2020.
Code to accompany our paper:
Randomized Smoothing of All Shapes and Sizes
Greg Yang*, Tony Duan*, J. Edward Hu, Hadi Salman, Ilya Razenshteyn, Jerry Li.
International Conference on Machine Learning (ICML), 2020 [Paper] [Blog Post]
Notably, we outperform existing provably L1-robust classifiers on ImageNet and CIFAR-10.
This library implements the algorithms in our paper for computing robust radii for different smoothing distributions against different adversaries; for example, distributions of the form against adversary.
The following summarizes the (distribution, adversary) pairs covered here.
We can compare the certified robust radius each of these distributions implies at a fixed level of the lower bound on the probability that the classifier returns the top class under noise. Here all noises are instantiated for CIFAR-10 dimensionality (d = 3072) and normalized to variance . Note that the first two rows below certify for the adversary while the last row certifies for the adversary and the adversary. For more details see our tutorial.ipynb notebook.
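To make this comparison concrete, the sketch below evaluates closed-form radii for three common noises as a function of the probability lower bound. It assumes the formulas reported in the literature: sigma * Phi^{-1}(p) for Gaussian noise against an L2 adversary (Cohen et al. 2019), -lambda * log(2(1 - p)) for Laplace noise against an L1 adversary (Teng et al. 2020), and 2 * lambda * (p - 1/2) for uniform noise against an L1 adversary, with lambda chosen so that each coordinate has variance sigma^2. The function names are illustrative and are not part of this library's API.

```python
import math
from statistics import NormalDist


def gaussian_l2_radius(sigma, prob_lb):
    """L2 radius for Gaussian noise with per-coordinate std sigma."""
    return sigma * NormalDist().inv_cdf(prob_lb)


def laplace_l1_radius(sigma, prob_lb):
    """L1 radius for Laplace noise; scale lambd gives variance sigma^2."""
    lambd = sigma / math.sqrt(2.0)
    return -lambd * math.log(2.0 * (1.0 - prob_lb))


def uniform_l1_radius(sigma, prob_lb):
    """L1 radius for uniform noise on [-lambd, lambd]^d with variance sigma^2."""
    lambd = math.sqrt(3.0) * sigma
    return 2.0 * lambd * (prob_lb - 0.5)


if __name__ == "__main__":
    # Compare the three radii at a few probability lower bounds, sigma = 1.
    for p in (0.6, 0.9, 0.99):
        print(p,
              gaussian_l2_radius(1.0, p),
              laplace_l1_radius(1.0, p),
              uniform_l1_radius(1.0, p))
```

All three formulas are meaningful only for prob_lb at or above 0.5; at exactly 0.5 each radius is zero.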
Clone our repository and install dependencies:
git clone https://github.com/tonyduan/rs4a.git
conda create --name rs4a python=3.6
conda activate rs4a
conda install numpy matplotlib pandas seaborn
conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
pip install torchnet tqdm statsmodels dfply
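To reproduce our SOTA results on CIFAR-10, we need to train models over a range of noise levels sigma (the pre-trained CIFAR-10 models listed below cover sigma from 0.15 to 3.5).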
For each value of sigma, run the following:
python3 -m src.train \
    --model=WideResNet \
    --noise=Uniform \
    --sigma={sigma} \
    --experiment-name=cifar_uniform_{sigma}
python3 -m src.test \
    --model=WideResNet \
    --noise=Uniform \
    --sigma={sigma} \
    --experiment-name=cifar_uniform_{sigma} \
    --sample-size-cert=100000 \
    --sample-size-pred=64 \
    --noise-batch-size=512
The training script trains the model via data augmentation with the specified noise and level of sigma, and saves the model checkpoint to the directory ckpts/experiment_name.
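As a rough illustration of noise augmentation, the sketch below draws Gaussian, Uniform, or Laplace noise scaled so that each coordinate has variance sigma^2, and adds it to an input. This is a standard-library stand-in, not the code in src/noises; the function names and the assumption that sigma denotes the per-coordinate standard deviation are ours for illustration.

```python
import math
import random


def sample_noise(name, sigma, dim, rng):
    """Draw dim i.i.d. noise coordinates with variance sigma^2 (assumed convention)."""
    if name == "Gaussian":
        return [rng.gauss(0.0, sigma) for _ in range(dim)]
    if name == "Uniform":
        lambd = math.sqrt(3.0) * sigma  # Var of U[-lambd, lambd] is lambd^2 / 3
        return [rng.uniform(-lambd, lambd) for _ in range(dim)]
    if name == "Laplace":
        lambd = sigma / math.sqrt(2.0)  # Var of Laplace(0, lambd) is 2 * lambd^2
        # The difference of two i.i.d. Exp(1) variables is Laplace(0, 1).
        return [lambd * (rng.expovariate(1.0) - rng.expovariate(1.0))
                for _ in range(dim)]
    raise ValueError(f"unknown noise: {name}")


def augment(x, name, sigma, rng):
    """Return a noisy copy of the flat input vector x."""
    noise = sample_noise(name, sigma, len(x), rng)
    return [xi + ni for xi, ni in zip(x, noise)]


if __name__ == "__main__":
    rng = random.Random(0)
    print(augment([0.1, 0.5, 0.9], "Uniform", 0.5, rng))
```

In an actual training loop, a fresh noise sample would be drawn for every batch before the forward pass.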
The testing script loads the model checkpoint from the ckpts/experiment_name directory, makes predictions over the entire test set using the smoothed classifier, and certifies the robust radii of these predictions. Note that by default we make predictions with 64 samples, certify with 100,000 samples, and use a failure probability of 0.001.
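Certification needs a high-confidence lower bound on the probability of the top class under noise. One standard way to obtain it is a one-sided Clopper-Pearson bound on a binomial proportion, as used by Cohen et al. 2019; whether src.test computes exactly this bound is an assumption here. The sketch below is a standard-library illustration with hypothetical function names.

```python
import math


def binom_tail_ge(k, n, p):
    """P(X >= k) for X ~ Binomial(n, p), with each term computed in log space."""
    if k <= 0:
        return 1.0
    if p <= 0.0:
        return 0.0
    if p >= 1.0:
        return 1.0
    log_p, log_q = math.log(p), math.log1p(-p)
    total = 0.0
    for i in range(k, n + 1):
        log_pmf = (math.lgamma(n + 1) - math.lgamma(i + 1)
                   - math.lgamma(n - i + 1) + i * log_p + (n - i) * log_q)
        total += math.exp(log_pmf)
    return min(total, 1.0)


def clopper_pearson_lower(k, n, alpha):
    """Largest p whose chance of producing >= k successes is at most alpha."""
    if k == 0:
        return 0.0
    lo, hi = 0.0, 1.0
    for _ in range(60):  # bisection; the tail probability increases with p
        mid = (lo + hi) / 2.0
        if binom_tail_ge(k, n, mid) < alpha:
            lo = mid
        else:
            hi = mid
    return lo


if __name__ == "__main__":
    # 990 of 1000 noisy samples agreed with the top class, alpha = 0.001.
    print(clopper_pearson_lower(990, 1000, 0.001))
```

The explicit sum is slow for 100,000 samples; in practice a library routine such as scipy.stats.beta.ppf(alpha, k, n - k + 1) gives the same lower bound far more quickly.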
To draw a comparison to the benchmark noises, re-run the above replacing Uniform with Gaussian and Laplace. Then, to plot the figures and print the table of results (for the L1 adversary), run our analysis script:
python3 -m scripts.analyze --dir=ckpts --show --adv=1
Note that other noises need to be instantiated with the appropriate arguments when the training/testing code is invoked. For example, to train with the ExpInf noise with parameters k=10 and j=100, we would run:
python3 -m src.train \
    --noise=ExpInf \
    --k=10 \
    --j=100 \
    --sigma=0.5 \
    --experiment-name=cifar_expinf_0.5
Our pre-trained models are available.
The following commands will download all models into the pretrain/ directory.
mkdir -p pretrain
wget --directory-prefix=pretrain http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_all.zip
unzip -d pretrain pretrain/cifar_all.zip
wget --directory-prefix=pretrain http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_all.zip
unzip -d pretrain pretrain/imagenet_all.zip
ImageNet (ResNet-50): [All Models, 2.3 GB]
- Sigma=0.25: [Uniform] [Gaussian] [Laplace]
- Sigma=0.5: [Uniform] [Gaussian] [Laplace]
- Sigma=0.75: [Uniform] [Gaussian] [Laplace]
- Sigma=1.0: [Uniform] [Gaussian] [Laplace]
- Sigma=1.25: [Uniform] [Gaussian] [Laplace]
- Sigma=1.5: [Uniform]
- Sigma=1.75: [Uniform]
- Sigma=2.0: [Uniform]
- Sigma=2.25: [Uniform]
- Sigma=2.5: [Uniform]
- Sigma=2.75: [Uniform]
- Sigma=3.0: [Uniform]
- Sigma=3.25: [Uniform]
- Sigma=3.5: [Uniform]
CIFAR-10 (Wide ResNet 40-2): [All Models, 226 MB]
- Sigma=0.15: [Uniform] [Gaussian] [Laplace]
- Sigma=0.25: [Uniform] [Gaussian] [Laplace]
- Sigma=0.5: [Uniform] [Gaussian] [Laplace]
- Sigma=0.75: [Uniform] [Gaussian] [Laplace]
- Sigma=1.0: [Uniform] [Gaussian] [Laplace]
- Sigma=1.25: [Uniform] [Gaussian] [Laplace]
- Sigma=1.5: [Uniform]
- Sigma=1.75: [Uniform]
- Sigma=2.0: [Uniform]
- Sigma=2.25: [Uniform]
- Sigma=2.5: [Uniform]
- Sigma=2.75: [Uniform]
- Sigma=3.0: [Uniform]
- Sigma=3.25: [Uniform]
- Sigma=3.5: [Uniform]
By default the models above were trained with noise augmentation. We further improve upon our state-of-the-art certified accuracies using recent advances in training smoothed classifiers: (1) by using stability training (Li et al. NeurIPS 2019), and (2) by leveraging additional data using (a) pre-training on downsampled ImageNet (Hendrycks et al. NeurIPS 2019) and (b) semi-supervised self-training with data from 80 Million Tiny Images (Carmon et al. 2019). Our improved models trained with these methods are released below.
ImageNet (ResNet-50):
- Stability training: [All Models, 2.3 GB]
CIFAR-10 (Wide ResNet 40-2):
- Stability training: [All Models, 234 MB]
- Stability training + pre-training: [All Models, 236 MB]
- Stability training + semi-supervised learning: [All Models, 235 MB]
An example of pre-trained model usage is below. For a more in-depth example, see our tutorial.ipynb notebook.
import torch
from src.models import WideResNet
from src.noises import Uniform
from src.smooth import *
# load the model
model = WideResNet(dataset="cifar", device="cuda")
saved_dict = torch.load("pretrain/cifar_uniform_050.pt")
model.load_state_dict(saved_dict)
model.eval()
# instantiate the noise
noise = Uniform(device="cpu", dim=3072, sigma=0.5)
# x is assumed to be a batch of CIFAR-10 images that has already been loaded
# training code, to generate noisy samples
noisy_x = noise.sample(x)
# testing code, certify for the L1 adversary
preds = smooth_predict_hard(model, x, noise, 64)
top_cats = preds.probs.argmax(dim=1)
prob_lb = certify_prob_lb(model, x, top_cats, 0.001, noise, 100000)
radius = noise.certify(prob_lb, adv=1)
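To illustrate what a hard-vote smoothed prediction looks like, the toy sketch below classifies noisy copies of an input with a base classifier and returns the most frequent class. It is a standard-library stand-in under the assumption that smooth_predict_hard performs a majority vote over noise samples; the toy classifier and all function names are hypothetical and not part of this library.

```python
import math
import random
from collections import Counter


def smooth_predict_vote(base_classifier, x, sigma, num_samples, rng):
    """Majority vote of base_classifier over uniform-noise copies of x."""
    lambd = math.sqrt(3.0) * sigma  # uniform on [-lambd, lambd], variance sigma^2
    votes = Counter()
    for _ in range(num_samples):
        noisy_x = [xi + rng.uniform(-lambd, lambd) for xi in x]
        votes[base_classifier(noisy_x)] += 1
    top_class, _ = votes.most_common(1)[0]
    return top_class, votes


def toy_classifier(z):
    """Hypothetical base classifier: class 1 if the first coordinate is positive."""
    return int(z[0] > 0.0)


if __name__ == "__main__":
    rng = random.Random(0)
    top, votes = smooth_predict_vote(toy_classifier, [0.3, 0.0], 0.5, 64, rng)
    print(top, dict(votes))
```

The vote counts are what a certification step would turn into a probability lower bound and then into a radius.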
- ckpts/ is used to store experiment checkpoints and results.
- data/ is used to store image datasets.
- tables/ contains caches of pre-calculated tables of certified radii.
- src/ contains the main source code.
- scripts/ contains the analysis and plotting code.
Within the src/ directory, the most salient files are:
- train.py is used to train models and save to ckpts/.
- test.py is used to test and compute robust certificates for adversaries.
- noises/test_noises.py is a unit test for the noises we include. Run the test with python -m unittest src/noises/test_noises.py. Note that some tests are probabilistic and can fail occasionally. If so, rerun a few more times to make sure the failure is not persistent.
- noises/noises.py is a library of noises derived for randomized smoothing.