main_ewc.py
import argparse
from copy import deepcopy

import numpy as np
import torch
from tensorboardX import SummaryWriter
from torch.autograd import Variable

import plot
import utils
from model import Model


def main():
    # Command line args
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=100, metavar='N',
                        help='input batch size for training (default: 100)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    # 1 is just for speed when testing - the original EWC paper hyperparameters are here:
    # https://arxiv.org/pdf/1612.00796.pdf#section.4
    # This experiment uses 100 epochs:
    # https://github.com/stokesj/EWC
    parser.add_argument('--epochs', type=int, default=1, metavar='N',
                        help='number of epochs to train (default: 1)')
    # This learning rate is the same as the one used by:
    # https://github.com/ariseff/overcoming-catastrophic/blob/afea2d3c9f926d4168cc51d56f1e9a92989d7af0/model.py#L114
    #
    # The original EWC paper hyperparameters are here:
    # https://arxiv.org/pdf/1612.00796.pdf#section.4
    parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                        help='learning rate (default: 0.1)')
    # We don't want an L2 regularization penalty because https://arxiv.org/pdf/1612.00796.pdf#subsection.2.1
    # (see figure 2A) shows that this would prevent the network from learning another task.
    parser.add_argument('--l2-reg-penalty', type=float, default=0.0, metavar='L2',
                        help='l2 regularization penalty (weight decay) (default: 0.0)')
    # This is the lambda (fisher multiplier) value used by:
    # https://github.com/ariseff/overcoming-catastrophic/blob/master/experiment.ipynb - see In [17]
    #
    # Some other examples:
    # 400 (from https://arxiv.org/pdf/1612.00796.pdf#subsection.4.2)
    # inverse of the learning rate (1.0 / lr) (from https://github.com/stokesj/EWC - see the README)
    parser.add_argument('--lam', type=float, default=15, metavar='LAM',
                        help='ewc lambda value (fisher multiplier) (default: 15)')
    # only needed if the SGD optimizer with momentum is desired, hence the default is 0.0
    parser.add_argument('--momentum', type=float, default=0.0, metavar='M',
                        help='SGD momentum (default: 0.0)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed-torch', type=int, default=1, metavar='ST',
                        help='random seed for PyTorch (default: 1)')
    parser.add_argument('--seed-numpy', type=int, default=1, metavar='SN',
                        help='random seed for numpy (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    # since the validation set, which is drawn from the training set, is size 200 by default, the rest of the
    # training data are used as the actual data on which the network is trained: 60000 - 200 = 59800
    parser.add_argument('--train-dataset-size', type=int, default=59800, metavar='TDS',
                        help='how many images to put in the training dataset')
    # size of the validation set
    #
    # For reference, a validation set size of 1024 is used by:
    # https://github.com/kuc2477/pytorch-ewc/blob/4a75734ef091e91a83ce82cab8b272be61af3ab6/main.py#L24
    parser.add_argument('--validation-dataset-size', type=int, default=200, metavar='VDS',
                        help='how many images to put in the validation dataset')
    # the number of samples used in the computation of the Fisher Information
    parser.add_argument('--fisher-num-samples', type=int, default=200)
    # size of hidden layer(s)
    parser.add_argument('--hidden-size', type=int, default=50)
    # number of hidden layers
    # TODO implement this - currently does not actually modify network structure...
    parser.add_argument('--hidden-layer-num', type=int, default=1)
    # Dropout probability for hidden layers - see:
    # https://arxiv.org/pdf/1612.00796.pdf#section.4
    parser.add_argument('--hidden-dropout-prob', type=float, default=0.5)
    # Dropout probability for input layer - see:
    # https://arxiv.org/pdf/1612.00796.pdf#section.4
    parser.add_argument('--input-dropout-prob', type=float, default=0.2)
    args = parser.parse_args()
    # determines if CUDA should be used - only if available AND not disabled via arguments
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    # arguments specific to CUDA computation
    # num_workers: how many subprocesses to use for data loading - if set to 0, data will be loaded in the main process
    # pin_memory: if True, the DataLoader will copy tensors into CUDA pinned memory before returning them
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    # set the device on which to perform computations - later calls to .to(device) will move tensors to GPU or CPU
    # based on the value determined here
    device = torch.device("cuda" if use_cuda else "cpu")
    # print 8 digits of precision when displaying floating point output from tensors
    torch.set_printoptions(precision=8)
    # set a manual seed for PyTorch random number generation
    torch.manual_seed(args.seed_torch)
    # set a manual seed for numpy random number generation
    np.random.seed(args.seed_numpy)
    # Instantiate a model that will be trained using EWC.
    #
    # .to(device):
    # Move all parameters and buffers in the model to device (CPU or GPU - set above).
    # Both integral and floating point values are moved.
    ewc_model = Model(
        args.hidden_size,
        input_size=784,   # 28 x 28 pixels = 784 pixels per MNIST image
        output_size=10,   # 10 classes - digits 0-9
        ewc=True,         # use EWC
        lam=args.lam      # the lambda (fisher multiplier) value to be used in the EWC loss formula
    ).to(device)
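    # NOTE (illustrative summary, inferred from the arguments above rather than from model.py itself):
    # with the default arguments this instantiates a fully connected classifier with 784 inputs, a hidden
    # layer of 50 units, and 10 output classes. The EWC-specific state used later in this loop (Fisher
    # diagonals, theta* snapshots, running penalty sums) is kept inside the Model instance.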
    # A list of the different DataLoader objects that hold various permutations of the MNIST testing dataset -
    # we keep these around in a persistent list here so that we can use them to test each of the models in the
    # list "models" after they are trained on the latest task's training dataset.
    # For more details, see: generate_new_mnist_task() in utils.py
    test_loaders = []
    # the number of the task on which we are CURRENTLY training in the loop below (as opposed to a count of the
    # number of tasks on which we have already trained) - e.g. when training on task 3 this value will be 3
    task_count = 1
    # a list holding a dictionary of the format
    # {task number: size of the network's hidden layer when the network was trained on that task}
    model_size_dictionaries = []
    # initialize the model size dictionary
    model_size_dictionaries.append({})
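    # For illustration (assuming the default --hidden-size of 50 and no expansion yet): after three tasks
    # this dictionary would look like {1: 50, 2: 50, 3: 50}. If ewc_model.expand() grows the hidden layer,
    # the sizes recorded for subsequent tasks would presumably be larger.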
    # keep learning tasks ad infinitum
    while True:
        print(ewc_model)
        # get the DataLoaders for the training, validation, and testing data
        train_loader, validation_loader, test_loader = utils.generate_new_mnist_task(
            args.train_dataset_size,
            args.validation_dataset_size,
            args.batch_size,
            args.test_batch_size,
            kwargs,
            first_task=(task_count == 1)  # if first_task is True, we won't permute the MNIST dataset.
        )
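        # Background (describing the permuted-MNIST benchmark from the EWC paper, not utils.py itself):
        # each new task applies one fixed random permutation of the 784 input pixels to every training,
        # validation, and test image, so successive tasks are equally hard but favor different weights;
        # the exact permutation logic here is whatever generate_new_mnist_task() implements.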
        # add the new test_loader for this task to the list of testing dataset DataLoaders for later re-use
        # to evaluate how well the models retain accuracy on old tasks after learning new ones
        #
        # NOTE: this list also includes the current test_loader, which we are appending here, because we also
        # need to test each network on the current task after training
        test_loaders.append(test_loader)
        # for each desired epoch, train the model on the latest task
        for epoch in range(1, args.epochs + 1):
            ewc_model.train_model(args, device, train_loader, epoch, task_count)
        # update the model size dictionary
        model_size_dictionaries[0].update({task_count: ewc_model.hidden_size})
        # generate a dictionary mapping tasks to models of the sizes that the network was when those tasks were
        # trained, containing subsets of the weights currently in the model (to mask new, post-expansion weights
        # when testing on tasks for which the weights did not exist during training)
        test_models = utils.generate_model_dictionary(ewc_model, model_size_dictionaries[0])
        # test the model on ALL tasks trained thus far (including the current task)
        utils.test(test_models, device, test_loaders)
        # use the validation set in the Fisher Information Matrix computation, as specified by:
        # https://github.com/ariseff/overcoming-catastrophic/blob/master/experiment.ipynb
        ewc_model.estimate_fisher(device, validation_loader)
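        # Background (paraphrasing the EWC paper and the implementations referenced above, not
        # estimate_fisher() itself): the diagonal of the Fisher Information Matrix is typically estimated
        # as the mean squared gradient of the log-likelihood over the sampled data, i.e.
        #   F_i ~= E[ (d/d theta_i log p(y | x, theta))^2 ],
        # evaluated at the weights reached after training on the current task.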
        # update the ewc loss sums in the model to incorporate weights and fisher info from the task on which
        # we just trained the network
        ewc_model.update_ewc_sums()
        # store the current fisher diagonals for use with plotting and comparative loss calculations
        # using the method in model.alternative_ewc_loss()
        ewc_model.task_fisher_diags.update(
            {task_count: deepcopy(ewc_model.list_of_fisher_diags)})
        # save the theta* ("theta star") values after training - for plotting and comparative loss calculations
        # using the method in model.alternative_ewc_loss()
        #
        # NOTE: when I reference theta*, I am referring to the values represented by that variable in
        # equation (3) at:
        # https://arxiv.org/pdf/1612.00796.pdf#section.2
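        # For reference, equation (3) of the paper defines the loss when training on task B after task A as
        #   L(theta) = L_B(theta) + sum_i (lambda / 2) * F_i * (theta_i - theta*_{A,i})^2
        # where F_i is the Fisher diagonal for parameter i and theta*_{A,i} is that parameter's value after
        # training on task A - the "theta star" values saved below.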
        current_weights = []
        for parameter in ewc_model.parameters():
            current_weights.append(deepcopy(parameter.data.clone()))
        ewc_model.task_post_training_weights.update({task_count: deepcopy(current_weights)})
        # expand the network after training on task 20 and before training on task 21...
        if task_count == 20:
            ewc_model.expand()
        # increment the number of the current task before re-entering the while loop
        task_count += 1


if __name__ == '__main__':
    main()
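# Example invocation (a sketch using the argument names defined above; the values shown are the
# defaults, so a bare `python main_ewc.py` is equivalent):
#   python main_ewc.py --epochs 1 --lr 0.1 --lam 15 --hidden-size 50 --batch-size 100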