model.py
import tensorflow as tf
import numpy as np
from copy import deepcopy
import matplotlib.pyplot as plt
from IPython import display


# variable initialization functions
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)


def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

class Model:
    def __init__(self, x, y_):
        in_dim = int(x.get_shape()[1])    # 784 for MNIST
        out_dim = int(y_.get_shape()[1])  # 10 for MNIST

        self.x = x  # input placeholder

        # simple 2-layer network
        W1 = weight_variable([in_dim, 50])
        b1 = bias_variable([50])
        W2 = weight_variable([50, out_dim])
        b2 = bias_variable([out_dim])

        h1 = tf.nn.relu(tf.matmul(x, W1) + b1)  # hidden layer
        self.y = tf.matmul(h1, W2) + b2         # output layer

        self.var_list = [W1, b1, W2, b2]

        # vanilla single-task loss
        self.cross_entropy = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=self.y))
        self.set_vanilla_loss()

        # performance metrics
        correct_prediction = tf.equal(tf.argmax(self.y, 1), tf.argmax(y_, 1))
        self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
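
    # compute_fisher (below) builds a diagonal approximation of the Fisher
    # information: for each sampled image, a class label is drawn from the
    # model's own softmax, the gradient of that class's log-probability is
    # taken with respect to every parameter, the squared gradients are
    # accumulated, and the totals are averaged over num_samples.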
    def compute_fisher(self, imgset, sess, num_samples=200, plot_diffs=False, disp_freq=10):
        # compute Fisher information for each parameter

        # initialize Fisher information for most recent task
        self.F_accum = []
        for v in range(len(self.var_list)):
            self.F_accum.append(np.zeros(self.var_list[v].get_shape().as_list()))

        # sampling a random class from softmax
        probs = tf.nn.softmax(self.y)
        class_ind = tf.to_int32(tf.multinomial(tf.log(probs), 1)[0][0])

        if plot_diffs:
            # track differences in mean Fisher info
            F_prev = deepcopy(self.F_accum)
            mean_diffs = np.zeros(0)

        # gradients of the log-probability of the sampled class w.r.t. all parameters
        fisher_grads = tf.gradients(tf.log(probs[0, class_ind]), self.var_list)

        for i in range(num_samples):
            # select random input image
            im_ind = np.random.randint(imgset.shape[0])
            # compute first-order derivatives
            ders = sess.run(fisher_grads, feed_dict={self.x: imgset[im_ind:im_ind+1]})
            # square the derivatives and add to total
            for v in range(len(self.F_accum)):
                self.F_accum[v] += np.square(ders[v])
            if plot_diffs:
                if i % disp_freq == 0 and i > 0:
                    # recording mean diffs of F
                    F_diff = 0
                    for v in range(len(self.F_accum)):
                        F_diff += np.sum(np.absolute(self.F_accum[v]/(i+1) - F_prev[v]))
                    mean_diff = np.mean(F_diff)
                    mean_diffs = np.append(mean_diffs, mean_diff)
                    for v in range(len(self.F_accum)):
                        F_prev[v] = self.F_accum[v]/(i+1)
                    plt.plot(range(disp_freq+1, i+2, disp_freq), mean_diffs)
                    plt.xlabel("Number of samples")
                    plt.ylabel("Mean absolute Fisher difference")
                    display.display(plt.gcf())
                    display.clear_output(wait=True)

        # divide totals by number of samples
        for v in range(len(self.F_accum)):
            self.F_accum[v] /= num_samples

    def star(self):
        # used for saving optimal weights after most recent task training
        self.star_vars = []
        for v in range(len(self.var_list)):
            self.star_vars.append(self.var_list[v].eval())

    def restore(self, sess):
        # reassign optimal weights for latest task
        if hasattr(self, "star_vars"):
            for v in range(len(self.var_list)):
                sess.run(self.var_list[v].assign(self.star_vars[v]))

    def set_vanilla_loss(self):
        self.train_step = tf.train.GradientDescentOptimizer(0.1).minimize(self.cross_entropy)
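
    # EWC adds a quadratic penalty that pulls each parameter toward its value
    # after the previous task, weighted by its Fisher information:
    #     loss = cross_entropy + sum_i (lam / 2) * F_i * (theta_i - theta_star_i)^2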
    def update_ewc_loss(self, lam):
        # elastic weight consolidation
        # lam is weighting for previous task(s) constraints
        if not hasattr(self, "ewc_loss"):
            self.ewc_loss = self.cross_entropy
        for v in range(len(self.var_list)):
            self.ewc_loss += (lam/2) * tf.reduce_sum(tf.multiply(
                self.F_accum[v].astype(np.float32),
                tf.square(self.var_list[v] - self.star_vars[v])))
        self.train_step = tf.train.GradientDescentOptimizer(0.1).minimize(self.ewc_loss)
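

if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Minimal usage sketch (not part of the original file): train on MNIST,
    # consolidate with compute_fisher()/star(), then keep training on a
    # permuted-pixel variant of MNIST under the EWC penalty. The data
    # loader, batch size, iteration counts, and lambda below are
    # illustrative assumptions, not values taken from this repository.
    # ------------------------------------------------------------------
    from tensorflow.examples.tutorials.mnist import input_data

    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

    x = tf.placeholder(tf.float32, shape=[None, 784])
    y_ = tf.placeholder(tf.float32, shape=[None, 10])
    model = Model(x, y_)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # task A: plain MNIST with the vanilla cross-entropy loss
        for _ in range(800):
            batch = mnist.train.next_batch(100)
            sess.run(model.train_step, feed_dict={x: batch[0], y_: batch[1]})

        # consolidate: estimate the Fisher diagonal and anchor the weights
        model.compute_fisher(mnist.validation.images, sess, num_samples=200)
        model.star()

        # task B: the same images with a fixed random pixel permutation,
        # trained with the EWC-regularized loss (lambda chosen arbitrarily)
        model.update_ewc_loss(lam=15)
        perm = np.random.permutation(784)
        for _ in range(800):
            batch = mnist.train.next_batch(100)
            sess.run(model.train_step,
                     feed_dict={x: batch[0][:, perm], y_: batch[1]})

        # accuracy on the original task after learning task B
        print("task A accuracy:",
              sess.run(model.accuracy,
                       feed_dict={x: mnist.test.images, y_: mnist.test.labels}))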