-
Notifications
You must be signed in to change notification settings - Fork 0
/
GEMModel.py
215 lines (186 loc) · 7.29 KB
/
GEMModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import quadprog
# Auxiliary functions useful for GEM's inner optimization.
def compute_offsets(task, nc_per_task, is_cifar):
    """
    Return the half-open range [start, end) of output columns used by `task`.

    On CIFAR every task owns its own block of `nc_per_task` consecutive
    outputs; on the other datasets every task uses the first `nc_per_task`
    outputs.
    """
    if not is_cifar:
        return 0, nc_per_task
    start = task * nc_per_task
    end = (task + 1) * nc_per_task
    return start, end
def store_grad(pp, grads, grad_dims, tid):
    """
    Flatten the current parameter gradients into column `tid` of `grads`.

    pp: callable returning the parameters (e.g. ``module.parameters``), in the
        same order that was used to build `grad_dims`.
    grads: 2-D tensor with ``sum(grad_dims)`` rows and one column per task.
    grad_dims: number of elements of each parameter, in ``pp()`` order.
    tid: task id, i.e. the column of `grads` to overwrite.

    The column is zeroed first, so parameters whose ``.grad`` is None leave
    their segment at zero.
    """
    grads[:, tid].fill_(0.0)
    beg = 0
    for param, numel in zip(pp(), grad_dims):
        end = beg + numel
        if param.grad is not None:
            # reshape (not view) also handles non-contiguous gradients.
            grads[beg:end, tid].copy_(param.grad.detach().reshape(-1))
        # Advance for every parameter, with or without a gradient, so each
        # parameter always maps to the same slice of the flat vector.
        beg = end
def overwrite_grad(pp, newgrad, grad_dims):
    """
    Write a flat gradient vector back into the parameters' ``.grad`` tensors.
    Used after a constraint violation has been corrected.

    pp: callable returning the parameters (e.g. ``module.parameters``), in the
        same order that was used to build `grad_dims`.
    newgrad: flat tensor with ``sum(grad_dims)`` elements (corrected gradient).
    grad_dims: number of elements of each parameter, in ``pp()`` order.

    Parameters whose ``.grad`` is None are skipped but still consume their
    segment of `newgrad`.
    """
    beg = 0
    for param, numel in zip(pp(), grad_dims):
        end = beg + numel
        if param.grad is not None:
            this_grad = newgrad[beg:end].reshape(param.grad.size())
            param.grad.data.copy_(this_grad)
        # Advance for every parameter so offsets stay aligned with grad_dims.
        beg = end
def project2cone2(gradient, memories, margin=0.5, eps=1e-3):
    """
    Solve the GEM dual QP and overwrite `gradient` in place with the
    projected update.

    With g the proposed gradient and G the matrix whose columns are the
    past-task gradients, quadprog solves

        min_v  1/2 v^T (G^T G + eps I) v + g^T G v    s.t.  v >= margin

    and ``g + G v`` is written back into `gradient`.

    input: gradient, tensor with p elements (e.g. shape (p, 1)); written in place
    input: memories, (p, t) tensor whose columns are past-task gradients
    input: margin, lower bound imposed on every dual variable
    input: eps, ridge added to the diagonal of the QP matrix. G^T G is only
           positive semi-definite (singular when past-task gradients are
           linearly dependent), but quadprog requires a positive-definite one.
    output: None (the result is stored in `gradient`)
    """
    memories_np = memories.cpu().t().double().numpy()
    gradient_np = gradient.cpu().contiguous().reshape(-1).double().numpy()
    t = memories_np.shape[0]
    P = np.dot(memories_np, memories_np.transpose())
    # Symmetrize against round-off, then regularize to make P strictly PD.
    P = 0.5 * (P + P.transpose()) + np.eye(t) * eps
    q = np.dot(memories_np, gradient_np) * -1
    # quadprog enforces C^T v >= b; C = I and b = margin give v >= margin.
    G = np.eye(t)
    h = np.zeros(t) + margin
    v = quadprog.solve_qp(P, q, G, h)[0]
    x = np.dot(v, memories_np) + gradient_np
    # copy_ converts dtype/device; view_as matches the caller's shape.
    gradient.copy_(torch.from_numpy(x).view_as(gradient))
class Net(nn.Module):
    """
    Classifier trained task by task with Gradient Episodic Memory (GEM).

    For every task a ring buffer of up to `args.n_memories` examples is kept.
    Before each SGD step, the gradient of the current minibatch is compared
    with the gradient on each previously observed task's memory. If any dot
    product is negative, the update is projected with `project2cone2`.

    n_inputs: number of features per (flattened) input example.
    n_outputs: total number of output classes.
    n_tasks: number of tasks; task ids passed to `observe` index the memory,
        so they must lie in [0, n_tasks).
    args: namespace read for `n_layers`, `n_hiddens`, `memory_strength`,
        `data_file`, `lr`, `n_memories` and `cuda`.
    net: optional backbone module mapping inputs to `n_outputs` logits. When
        omitted, an MLP with `args.n_layers` ReLU hidden layers of
        `args.n_hiddens` units is built.
    """

    def __init__(self,
                 n_inputs,
                 n_outputs,
                 n_tasks,
                 args,
                 net=None):
        super(Net, self).__init__()
        nl, nh = args.n_layers, args.n_hiddens
        self.margin = args.memory_strength
        self.is_cifar = (args.data_file == 'cifar100.pt')

        # The backbone must be registered before the optimizer and gradient
        # buffers are built: both are sized from self.parameters(), and SGD
        # raises on an empty parameter list.
        # NOTE(review): the original sketch used a ResNet18 for CIFAR, which
        # is not defined in this file; pass one via `net=` if required.
        if net is None:
            net = self._build_mlp([n_inputs] + [nh] * nl + [n_outputs])
        self.net = net

        self.ce = nn.CrossEntropyLoss()
        self.n_outputs = n_outputs
        self.opt = optim.SGD(self.parameters(), args.lr)
        self.n_memories = args.n_memories
        self.gpu = args.cuda

        # allocate episodic memory (zero-filled instead of uninitialised)
        self.memory_data = torch.zeros(
            n_tasks, self.n_memories, n_inputs, dtype=torch.float32)
        self.memory_labs = torch.zeros(
            n_tasks, self.n_memories, dtype=torch.long)
        if args.cuda:
            self.memory_data = self.memory_data.cuda()
            self.memory_labs = self.memory_labs.cuda()
        # number of leading slots of each task's memory holding real examples
        self.mem_filled = [0] * n_tasks

        # allocate temporary synaptic memory: one flat gradient column per task
        self.grad_dims = [param.data.numel() for param in self.parameters()]
        self.grads = torch.zeros(sum(self.grad_dims), n_tasks)
        if args.cuda:
            self.grads = self.grads.cuda()

        # allocate counters
        self.observed_tasks = []
        self.old_task = -1
        self.mem_cnt = 0
        if self.is_cifar:
            self.nc_per_task = int(n_outputs / n_tasks)
        else:
            self.nc_per_task = n_outputs

    @staticmethod
    def _build_mlp(sizes):
        """Return Linear layers between consecutive `sizes`, with ReLU between them."""
        layers = []
        n_linear = len(sizes) - 1
        for i, (fan_in, fan_out) in enumerate(zip(sizes, sizes[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            if i < n_linear - 1:
                layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, x, t):
        """
        Return the backbone's logits for `x`.

        On CIFAR, logits outside task `t`'s output range are overwritten in
        place with -10e10, so that only task `t`'s classes can be predicted.
        """
        output = self.net(x)
        if self.is_cifar:
            # make sure we predict classes within the current task
            offset1 = int(t * self.nc_per_task)
            offset2 = int((t + 1) * self.nc_per_task)
            if offset1 > 0:
                output[:, :offset1].data.fill_(-10e10)
            if offset2 < self.n_outputs:
                output[:, offset2:self.n_outputs].data.fill_(-10e10)
        return output

    def _update_memory(self, x, t, y):
        """Copy the head of (x, y) into the free slots of task `t`'s ring buffer."""
        bsz = y.size(0)
        endcnt = min(self.mem_cnt + bsz, self.n_memories)
        effbsz = endcnt - self.mem_cnt
        # Examples that do not fit before the end of the buffer are dropped;
        # the write position wraps to slot 0 once the buffer is full.
        self.memory_data[t, self.mem_cnt:endcnt].copy_(x.detach()[:effbsz])
        self.memory_labs[t, self.mem_cnt:endcnt].copy_(y.detach()[:effbsz])
        self.mem_filled[t] = max(self.mem_filled[t], endcnt)
        self.mem_cnt += effbsz
        if self.mem_cnt == self.n_memories:
            self.mem_cnt = 0

    def observe(self, x, t, y):
        """
        Store (x, y) in task `t`'s episodic memory, then take one SGD step on
        (x, y). The gradient is projected first if it conflicts with the
        gradient on any past task's memory.

        x: input minibatch for task `t`.
        t: integer task id in [0, n_tasks).
        y: integer class labels for `x`; task `t`'s first output index is
            subtracted before the loss (non-zero only on CIFAR).
        """
        if t != self.old_task:
            self.observed_tasks.append(t)
            self.old_task = t
            # Start the new task's ring buffer at slot 0, so its filled slots
            # are always the prefix [0, mem_filled[t]) and unwritten slots
            # are never replayed.
            self.mem_cnt = 0
        self._update_memory(x, t, y)

        # Past tasks that actually have examples stored in memory.
        replayed = [tt for tt in self.observed_tasks[:-1]
                    if self.mem_filled[tt] > 0]

        # compute gradient on previous tasks
        for past_task in replayed:
            self.zero_grad()
            n_mem = self.mem_filled[past_task]
            offset1, offset2 = compute_offsets(past_task, self.nc_per_task,
                                               self.is_cifar)
            logits = self.forward(self.memory_data[past_task, :n_mem],
                                  past_task)
            ptloss = self.ce(logits[:, offset1:offset2],
                             self.memory_labs[past_task, :n_mem] - offset1)
            ptloss.backward()
            store_grad(self.parameters, self.grads, self.grad_dims, past_task)

        # now compute the grad on the current minibatch
        self.zero_grad()
        offset1, offset2 = compute_offsets(t, self.nc_per_task, self.is_cifar)
        loss = self.ce(self.forward(x, t)[:, offset1:offset2], y - offset1)
        loss.backward()

        # check if gradient violates constraints
        if replayed:
            store_grad(self.parameters, self.grads, self.grad_dims, t)
            indx = torch.tensor(replayed, dtype=torch.long,
                                device=self.grads.device)
            past_grads = self.grads.index_select(1, indx)
            dotp = torch.mm(self.grads[:, t].unsqueeze(0), past_grads)
            if (dotp < 0).any():
                # project2cone2 writes the projected gradient into
                # self.grads[:, t] through the unsqueezed view.
                project2cone2(self.grads[:, t].unsqueeze(1), past_grads,
                              self.margin)
                # copy gradients back
                overwrite_grad(self.parameters, self.grads[:, t],
                               self.grad_dims)
        self.opt.step()