-
Notifications
You must be signed in to change notification settings - Fork 1
/
cnn_sgnht.py
105 lines (94 loc) · 4.26 KB
/
cnn_sgnht.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import platform
print('python_version ==', platform.python_version())
import torch
print('torch.__version__ ==', torch.__version__)
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import time
import argparse
import numpy as np
import math
from sgnht import *
from evaluation import *
from model_zoo import *
# Command-line hyperparameters of the experiment.
parser = argparse.ArgumentParser(description='sgnht on CNN tested on CIFAR10 appending noise')
parser.add_argument('--train-batch-size', type=int, default=64,
                    help='mini-batch size of the training loader')
parser.add_argument('--test-batch-size', type=int, default=10000,
                    help='batch size of the test loader')
parser.add_argument('--num-burn-in', type=int, default=30000,
                    help='number of sampler iterations before test evaluation starts')
parser.add_argument('--num-epochs', type=int, default=1000,
                    help='number of passes over the training set')
parser.add_argument('--evaluation-interval', type=int, default=50,
                    help='report progress every this many sampler iterations')
parser.add_argument('--eta-theta', type=float, default=1.7e-8,
                    help='eta_theta hyperparameter forwarded to the SGNHT sampler')
parser.add_argument('--c-theta', type=float, default=0.01,
                    help='c_theta hyperparameter forwarded to the SGNHT sampler')
parser.add_argument('--prior-precision', type=float, default=1e-3,
                    help='weight of the squared-parameter penalty added to the loss')
parser.add_argument('--permutation', type=float, default=0.2,
                    help='fraction of each training batch whose labels are '
                         'replaced by random classes')
parser.add_argument('--enable-cuda', action='store_true',
                    help='run on the GPU when CUDA is available')
parser.add_argument('--device-num', type=int, default=3,
                    help='index of the CUDA device to use with --enable-cuda')
args = parser.parse_args()
print(args)

# Only touch the CUDA runtime when the run actually asked for it. Selecting
# the default device index unconditionally makes CPU-only runs crash on hosts
# that have CUDA but fewer GPUs than the default index requires.
if args.enable_cuda and torch.cuda.is_available():
    # Negative indices are left to torch.cuda.set_device, which treats them
    # as a no-op; only an index past the last device is rejected here.
    if args.device_num >= torch.cuda.device_count():
        parser.error('--device-num {} is out of range: only {} CUDA device(s) visible'.format(
            args.device_num, torch.cuda.device_count()))
    torch.cuda.set_device(args.device_num)
# Load CIFAR-10 (downloaded into ./cifar10-dataset on first use).
#
# Normalisation uses one mean/std entry per RGB channel. A single-element
# mean/std on three-channel images behaves differently across torchvision
# versions (first channel only vs. broadcast to all channels), so the
# statistics are spelled out per channel here.
# NOTE(review): this changes the input scaling relative to the previous
# single-value constants; re-check tuned sampler hyperparameters if results
# have to match earlier runs exactly.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# The same preprocessing is applied to both splits, so build it once.
_cifar10_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./cifar10-dataset', train=True, download=True,
                     transform=_cifar10_transform),
    batch_size=args.train_batch_size, shuffle=True, drop_last=True)

# The test split reuses the files fetched for the training split above.
# NOTE(review): drop_last=True silently discards test samples whenever
# --test-batch-size does not divide the test-set size; confirm the evaluator
# relies on equally sized batches before changing it.
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./cifar10-dataset', train=False,
                     transform=_cifar10_transform),
    batch_size=args.test_batch_size, shuffle=False, drop_last=True)
def corrupt_labels(y, fraction, num_labels):
    """Replace the leading share of a batch of labels with random classes.

    Args:
        y: 1-D integer label tensor of one mini-batch (left unmodified).
        fraction: Share of the batch to corrupt. The first
            ``int(fraction * len(y))`` entries are overwritten; the count is
            capped at the batch length.
        num_labels: Number of classes. Replacements are drawn uniformly from
            ``range(num_labels)`` using NumPy's global random state.

    Returns:
        ``y`` itself when ``fraction <= 0``, otherwise a corrupted clone.
    """
    if fraction <= 0.0:
        return y
    batch_size = y.size(0)
    num_noisy = min(int(fraction * batch_size), batch_size)
    noisy = y.clone()
    noisy[:num_noisy] = torch.as_tensor(
        np.random.choice(num_labels, num_noisy), dtype=torch.long)
    return noisy


def main():
    """Sample CNN weights with SGNHT on CIFAR-10 with partly randomised labels.

    Reads the module-level ``args``, ``train_loader`` and ``test_loader``.
    Prints the loss, the sampler's ``get_z_theta()`` value and the elapsed
    time every ``args.evaluation_interval`` iterations, and additionally the
    accuracy returned by the estimator once ``args.num_burn_in`` iterations
    have passed.
    """
    model = CNN()
    cuda_availability = args.enable_cuda and torch.cuda.is_available()
    if cuda_availability:
        # Move the parameters before the sampler is built so that any
        # per-parameter state it allocates on construction lives on the same
        # device as the parameters.
        model.cuda()
    N = len(train_loader.dataset)
    num_labels = model.outputdim
    sampler = SGNHT(model, N, args.eta_theta, args.c_theta)
    print(model)

    nIter = 0
    tStart = time.time()
    estimator = FullyBayesian((len(test_loader.dataset), num_labels),
                              model,
                              test_loader,
                              cuda_availability)
    sampler.resample_momenta()
    for epoch in range(1, 1 + args.num_epochs):
        print ("#######################################################################################")
        print ("This is the epoch: ", epoch)
        print ("#######################################################################################")
        for x, y in train_loader:
            # Label noise is injected on the CPU batch, before any transfer.
            y = corrupt_labels(y, args.permutation, num_labels)
            if cuda_availability:
                x, y = x.cuda(), y.cuda()
            model.zero_grad()
            # Re-assert training mode every step in case evaluation changed it.
            model.train()
            yhat = model(x)
            loss = F.cross_entropy(yhat, y)
            # Squared-parameter penalty weighted by the prior precision.
            for param in model.parameters():
                loss += args.prior_precision * torch.sum(param**2)
            loss.backward()
            # Update position and momentum.
            sampler.update()
            nIter += 1
            # NOTE(review): the nesting below (evaluate and print only after
            # burn-in; resample momenta and restart the timer once per
            # evaluation interval) should be confirmed against the intended
            # schedule.
            if nIter % args.evaluation_interval == 0:
                print('loss:{:6.4f}; thermostats:{:6.3f}; tElapsed:{:6.3f}'.format(
                    loss.item(),
                    sampler.get_z_theta(),
                    time.time() - tStart))
                if nIter >= args.num_burn_in:
                    acc = estimator.evaluation()
                    print ('This is the accuracy: %{:6.2f}'.format(acc))
                sampler.resample_momenta()
                tStart = time.time()


if __name__ == '__main__':
    main()