forked from CW-Huang/BayesianHypernet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexperiment_MLP_WN.py
126 lines (94 loc) · 3.63 KB
/
experiment_MLP_WN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# -*- coding: utf-8 -*-
"""
Created on Sun May 14 19:49:51 2017
@author: Chin-Wei
"""
from BHNs import MLPWeightNorm_BHN
from ops import load_mnist
from utils import log_normal, log_laplace
import numpy as np
def train_model(train_func,predict_func,X,Y,Xt,Yt,
                lr0=0.1,lrdecay=1,bs=20,epochs=50):
    """Run minibatch training and return the last loss of every epoch.

    Parameters
    ----------
    train_func : callable
        Called as ``train_func(x, y, N, lr)`` for every minibatch, where
        ``N`` is ``X.shape[0]``; its return value is recorded as the loss.
    predict_func : callable
        Called as ``predict_func(inputs)``; the result is compared
        element-wise with ``targets.argmax(1)`` to compute accuracy.
    X, Y : array-like
        Training inputs and targets, sliced along their first axis.
        ``Y`` must support ``argmax(1)`` (presumably one-hot labels).
    Xt, Yt : array-like
        Held-out inputs and targets, only used for the accuracy report.
    lr0 : float
        Initial learning rate.
    lrdecay : int or bool
        If truthy, the learning rate decays log-linearly from ``lr0`` in
        the first epoch to ``lr0 / 10`` in the last one; otherwise it
        stays at ``lr0``.
    bs : int
        Minibatch size.  The trailing ``N % bs`` examples are not used.
    epochs : int
        Number of passes over the training data.

    Returns
    -------
    list
        One entry per epoch: the loss of that epoch's last minibatch.

    Raises
    ------
    ValueError
        If at least one epoch is requested but ``bs`` does not allow a
        single full minibatch (``bs <= 0`` or ``bs > N``).
    """
    # Single-argument print(...) behaves the same on Python 2 and 3.
    print('trainset X.shape:{}, Y.shape:{}'.format(X.shape,Y.shape))
    N = X.shape[0]

    # Floor division keeps the batch count an int on Python 3 as well.
    n_batches = N // bs if bs > 0 else 0
    if epochs > 0 and n_batches < 1:
        # Without a batch there is no loss to record for the epoch.
        raise ValueError(
            'need at least one full minibatch: N={}, bs={}'.format(N, bs))

    # With a single epoch the original exponent divided by zero; clamp the
    # denominator so that one epoch simply trains at lr0.
    decay_span = float(max(epochs - 1, 1))

    records = []
    t = 0  # global minibatch counter across epochs
    for e in range(epochs):
        if lrdecay:
            lr = lr0 * 10**(-e/decay_span)
        else:
            lr = lr0

        for i in range(n_batches):
            x = X[i*bs:(i+1)*bs]
            y = Y[i*bs:(i+1)*bs]
            loss = train_func(x,y,N,lr)

            # Report progress every 100 minibatches (including the first).
            if t%100==0:
                print('epoch: {} {}, loss:{}'.format(e,t,loss))
                tr_acc = (predict_func(X)==Y.argmax(1)).mean()
                te_acc = (predict_func(Xt)==Yt.argmax(1)).mean()
                print('\ttrain acc: {}'.format(tr_acc))
                print('\ttest acc: {}'.format(te_acc))
            t+=1

        records.append(loss)

    return records
def evaluate_model(predict_proba,X,Y,Xt,Yt,n_mc=1000):
    """Evaluate accuracy from averaged repeated predictions.

    ``predict_proba`` is called ``n_mc`` times on both data sets, the
    outputs are averaged over the calls, and the arg-max of the average is
    compared with the arg-max of the targets.  (Presumably each call draws
    a fresh stochastic prediction -- confirm against the model.)

    Parameters
    ----------
    predict_proba : callable
        Called as ``predict_proba(inputs)``; must return an array of shape
        ``(inputs.shape[0], Y.shape[-1])``.
    X, Y : array-like
        Training inputs and targets; the class is ``Y.argmax(-1)``.
    Xt, Yt : array-like
        Held-out inputs and targets.
    n_mc : int
        Number of repeated predictions to average.

    Returns
    -------
    tuple
        ``(train_accuracy, heldout_accuracy)``.  The same two values are
        also printed.
    """
    n_train = X.shape[0]
    n_valid = Xt.shape[0]
    # Take the class count from the targets instead of hard-coding 10.
    n_classes = Y.shape[-1]

    MCt = np.zeros((n_mc,n_train,n_classes))
    # The held-out buffer must follow Xt's size, not the training size.
    MCv = np.zeros((n_mc,n_valid,n_classes))
    for i in range(n_mc):
        MCt[i] = predict_proba(X)
        MCv[i] = predict_proba(Xt)

    Y_pred = MCt.mean(0).argmax(-1)
    Y_true = Y.argmax(-1)
    Yt_pred = MCv.mean(0).argmax(-1)
    Yt_true = Yt.argmax(-1)

    tr = np.equal(Y_pred,Y_true).mean()
    va = np.equal(Yt_pred,Yt_true).mean()
    # Single-argument print(...) behaves the same on Python 2 and 3.
    print('train perf= {}'.format(tr))
    print('valid perf= {}'.format(va))

    # Indices of held-out examples whose averaged prediction is wrong.
    ind_negative = np.flatnonzero(Yt_pred != Yt_true)

    #TO-DO: complete evaluation
    if ind_negative.size:
        ind = ind_negative[0]
        # `ind` indexes the held-out set, so read the held-out draws (MCv);
        # show at most 15 draws, fewer when n_mc is smaller.
        for ii in range(min(15, n_mc)):
            print(np.round(MCv[ii][ind] * 1000))

    return tr, va
# Script entry point: parse the command line, build the model, train it on
# the first `size` MNIST training examples and evaluate it against the
# validation split.
if __name__ == '__main__':

    import argparse

    # Prior names accepted on the command line, mapped to the functions
    # imported from utils.
    priors = {'log_normal': log_normal,
              'log_laplace': log_laplace}

    parser = argparse.ArgumentParser()

    # boolean: 1 -> True ; 0 -> False
    parser.add_argument('--coupling',default=0,type=int)
    parser.add_argument('--perdatapoint',default=0,type=int)
    parser.add_argument('--lrdecay',default=0,type=int)

    parser.add_argument('--lr0',default=0.1,type=float)
    parser.add_argument('--lbda',default=1,type=float)
    parser.add_argument('--size',default=10000,type=int)
    parser.add_argument('--bs',default=20,type=int)
    parser.add_argument('--epochs',default=50,type=int)
    # `choices` rejects unknown names up front; previously an unknown prior
    # left `prior` undefined and failed later with a NameError.
    parser.add_argument('--prior',default='log_normal',type=str,
                        choices=sorted(priors))
    # Dataset location; the default is the path that used to be hard-coded.
    parser.add_argument('--filename',default='/data/lisa/data/mnist.pkl.gz',
                        type=str)

    args = parser.parse_args()
    print(args)

    coupling = args.coupling
    perdatapoint = args.perdatapoint
    lrdecay = args.lrdecay
    lr0 = args.lr0
    # np.cast was removed in NumPy 2.0; this builds the same float32 value.
    lbda = np.asarray(args.lbda, dtype='float32')
    bs = args.bs
    epochs = args.epochs
    prior = priors[args.prior]

    # Clamp the number of training examples to the range [10, 50000].
    size = max(10,min(50000,args.size))

    filename = args.filename
    train_x, train_y, valid_x, valid_y, test_x, test_y = load_mnist(filename)

    model = MLPWeightNorm_BHN(lbda=lbda,
                              perdatapoint=perdatapoint,
                              prior=prior,
                              coupling=coupling)

    # The validation split serves as the held-out set for both the progress
    # report during training and the final evaluation.
    recs = train_model(model.train_func,model.predict,
                       train_x[:size],train_y[:size],
                       valid_x,valid_y,
                       lr0,lrdecay,bs,epochs)

    evaluate_model(model.predict_proba,
                   train_x[:size],train_y[:size],
                   valid_x,valid_y)