-
Notifications
You must be signed in to change notification settings - Fork 18
/
neural.py
461 lines (390 loc) · 19 KB
/
neural.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
import argparse
import lasagne
import numpy as np
import os
import theano
import theano.tensor as T
try:
import theano.gpuarray.basic_ops as G
except ImportError:
import theano.sandbox.cuda.basic_ops as G
import time
from collections import Sequence, OrderedDict
from lasagne.layers import get_output, get_all_params
from lasagne.updates import total_norm_constraint
from theano.compile import MonitorMode
from theano.printing import pydotprint
from helpers import apply_nan_suppression
from vectorizers import BucketsVectorizer, RawVectorizer # NOQA: pickle backwards compatibility
from vectorizers import SymbolVectorizer, SequenceVectorizer # NOQA: pickle backwards compatibility
from stanza.monitoring import progress, summary
from stanza.research import config
from stanza.research.learner import Learner
from stanza.research.rng import get_rng
ColorVectorizer = BucketsVectorizer  # pickle backwards compatibility

# Command-line options read (via ``self.options``) by SimpleLasagneModel and
# NeuralLearner; registered on the stanza experiment options parser.
parser = config.get_options_parser()
parser.add_argument('--train_iters', type=int, default=10,
                    help='Number of iterations')
parser.add_argument('--train_epochs', type=int, default=100,
                    help='Number of epochs per iteration')
parser.add_argument('--batch_size', type=int, default=128,
                    help='Number of examples per minibatch for training and evaluation')
parser.add_argument('--detect_nans', type=config.boolean, default=False,
                    help='If True, throw an error if a non-finite value is detected.')
parser.add_argument('--verbosity', type=int, default=4,
                    help='Amount of diagnostic output to produce. 0-1: only progress updates; '
                         '2-3: plus major experiment steps; '
                         '4-5: plus compilation and graph assembly steps; '
                         '6-7: plus parameter names for each function compilation; '
                         '8: plus shapes and types for each compiled function call; '
                         '9-10: plus vectorization of all datasets')
parser.add_argument('--no_graphviz', type=config.boolean, default=False,
                    help='If `True`, do not use theano.printing.pydotprint to visualize '
                         'function graphs.')
parser.add_argument('--no_nan_suppression', type=config.boolean, default=False,
                    help='If `True`, do not try to suppress NaNs in training.')
parser.add_argument('--monitor_grads', type=config.boolean, default=False,
                    help='If `True`, return gradients for monitoring and write them to the '
                         'TensorBoard events file.')
parser.add_argument('--monitor_params', type=config.boolean, default=False,
                    help='If `True`, write parameter value histograms out to the '
                         'TensorBoard events file.')
# The value is passed to lasagne.updates.total_norm_constraint, which bounds the
# combined L2 norm of all gradients (not each element's absolute value).
parser.add_argument('--true_grad_clipping', type=float, default=5.0,
                    help='The maximum combined (L2) norm of all gradients; if it is '
                         'exceeded, all gradients are rescaled so their combined norm '
                         'equals this value. This gradient clipping is performed on the '
                         'full gradient calculation, not just the messages passing '
                         'through the LSTM. Set to 0 to disable.')
# Lowercase, non-dunder names in lasagne.nonlinearities mapped to their objects
# (the activation functions), minus the module's own ``theano`` import.
NONLINEARITIES = {
    name: func
    for name, func in lasagne.nonlinearities.__dict__.items()
    if name.islower() and not name.startswith('__')
}
# pop() instead of del: don't fail at import time if a Lasagne version no
# longer binds ``theano`` in that module's namespace.
NONLINEARITIES.pop('theano', None)

# Update rules exported in lasagne.updates.__all__, excluding the ``apply_*``
# helpers and the ``*_constraint`` functions, which are not optimizers.
OPTIMIZERS = {
    name: func
    for name, func in lasagne.updates.__dict__.items()
    if (name in lasagne.updates.__all__ and
        not name.startswith('apply_') and not name.endswith('_constraint'))
}

# Exported recurrent layer classes keyed by class name without the 'Layer'
# suffix; CustomRecurrentLayer is excluded.
CELLS = {
    name[:-len('Layer')]: func
    for name, func in lasagne.layers.recurrent.__dict__.items()
    if (name in lasagne.layers.recurrent.__all__ and name.endswith('Layer') and
        name != 'CustomRecurrentLayer')
}

# One module-level RNG, also installed as Lasagne's global RNG.
rng = get_rng()
lasagne.random.set_rng(rng)
def detect_nan(i, node, fn):
    '''
    Post-execution hook (installed via ``MonitorMode(post_func=...)``) that
    aborts on non-finite node outputs.

    Nodes whose op is an AllocEmpty or IncSubtensor (CPU or GPU variant) are
    not checked. For every other node, each output value that is not a
    ``np.random.RandomState`` must be entirely finite; otherwise the node's
    graph and its input/output values are printed and ``AssertionError`` is
    raised.

    :param i: unused
    :param node: the Apply node that just ran
    :param fn: the node's thunk; ``fn.inputs`` / ``fn.outputs`` are lists of
        one-element storage cells
    '''
    unchecked_ops = (T.AllocEmpty, T.IncSubtensor,
                     G.GpuAllocEmpty, G.GpuIncSubtensor)
    if isinstance(node.op, unchecked_ops):
        return
    for cell in fn.outputs:
        value = cell[0]
        # Random-state outputs are not numeric and cannot be tested for finiteness.
        if isinstance(value, np.random.RandomState):
            continue
        if np.isfinite(value).all():
            continue
        print('*** NaN detected ***')
        theano.printing.debugprint(node)
        print('Inputs : %s' % [in_cell[0] for in_cell in fn.inputs])
        print('Outputs: %s' % [out_cell[0] for out_cell in fn.outputs])
        raise AssertionError
def sample(a, temperature=1.0, random_state=None):
    '''
    Draw an index from a probability array reweighted by ``temperature``.

    Each probability ``p`` is replaced by ``p ** (1 / temperature)`` and the
    result renormalized before sampling. Inputs with more than one dimension
    are sampled independently along the first axis (recursively).

    :param a: array-like of non-negative weights (need not be normalized)
    :param temperature: values below 1 concentrate mass on likely indices,
        values above 1 flatten the distribution
    :param random_state: object with a ``multinomial`` method such as
        ``np.random.RandomState``; defaults to the module-level ``rng``
    :returns: an integer index for 1-D input, otherwise an array of indices
    :raises ValueError: if ``a`` is a scalar
    '''
    if random_state is None:
        random_state = rng
    a = np.array(a)
    if a.ndim < 1:
        raise ValueError('scalar is not a valid probability distribution')
    if a.ndim > 1:
        return np.array([sample(row, temperature, random_state) for row in a])
    # Cast to float64 before the log so the whole computation, not only the
    # normalization, runs at high precision. log(0) = -inf is expected and
    # maps back to a zero weight, so its warning is suppressed.
    with np.errstate(divide='ignore'):
        log_weights = np.log(a.astype(np.float64)) / temperature
    # Subtracting the max before exp() prevents every weight underflowing to
    # 0 at low temperatures (which would give 0/0 = NaN); the shift cancels
    # out in the normalization.
    weights = np.exp(log_weights - np.max(log_weights))
    weights /= np.sum(weights)
    return np.argmax(random_state.multinomial(1, weights, 1))
class Unpicklable(object):
    '''Stand-in stored in pickled state in place of an attribute that cannot be pickled.'''

    def __init__(self, name):
        # Name of the attribute this placeholder replaces.
        self.name = name

    def __repr__(self):
        return '<{} removed in pickling>'.format(self.name)
class SimpleLasagneModel(object):
    '''
    A Lasagne output layer bundled with compiled Theano training and
    prediction functions.

    ``train_fn(*inputs + targets + synth)`` applies one optimizer update and
    returns the values of the monitored expressions, in the order given by
    ``monitored_tags``. ``predict_fn(*inputs)`` returns the deterministic
    output of ``l_out``.
    '''

    def __init__(self, input_vars, target_vars, l_out, loss,
                 optimizer, learning_rate=0.001, id=None):
        '''
        Build the loss/update expressions and compile the train and predict
        functions.

        :param input_vars: sequence of Theano variables for the model inputs
        :param target_vars: sequence of Theano variables for the targets
            (this class's ``get_train_loss`` accepts exactly one)
        :param l_out: output Lasagne layer
        :param loss: callable ``loss(prediction, target)``; its mean is the
            training loss
        :param optimizer: Lasagne update function, called as
            ``optimizer(grads, params, learning_rate=learning_rate)``
        :param learning_rate: learning rate passed to ``optimizer``
        :param id: optional identifier prefixed to compiled-function names
            and log messages
        :raises ValueError: if ``input_vars`` or ``target_vars`` is not a
            sequence
        '''
        if not isinstance(input_vars, Sequence):
            raise ValueError('input_vars should be a sequence, instead got %s' % (input_vars,))
        if not isinstance(target_vars, Sequence):
            raise ValueError('target_vars should be a sequence, instead got %s' % (target_vars,))
        self.get_options()
        self.input_vars = input_vars
        self.l_out = l_out
        self.loss = loss
        self.optimizer = optimizer
        self.id = id
        id_tag = (self.id + '/') if self.id else ''
        id_tag_log = (self.id + ': ') if self.id else ''

        if self.options.verbosity >= 6:
            output_model_structure(l_out)

        params = self.params()
        (monitored,
         train_loss_grads,
         synth_vars) = self.get_train_loss(target_vars, params)
        # Materialized so the tag order matches train_fn's output order below.
        self.monitored_tags = list(monitored.keys())

        # A falsy clipping threshold (e.g. 0) disables gradient clipping.
        if self.options.true_grad_clipping:
            scaled_grads = total_norm_constraint(train_loss_grads, self.options.true_grad_clipping)
        else:
            scaled_grads = train_loss_grads
        updates = optimizer(scaled_grads, params, learning_rate=learning_rate)
        if not self.options.no_nan_suppression:
            # TODO: print_mode='all' somehow is always printing, even when
            # there are no NaNs. But tests are passing, even on GPU!
            updates = apply_nan_suppression(updates, print_mode='none')

        if self.options.detect_nans:
            mode = MonitorMode(post_func=detect_nan)
        else:
            mode = None

        if self.options.verbosity >= 2:
            print(id_tag_log + 'Compiling training function')
        train_inputs = input_vars + target_vars + synth_vars
        if self.options.verbosity >= 6:
            print('params = %s' % (train_inputs,))
        self.train_fn = theano.function(train_inputs, list(monitored.values()),
                                        updates=updates, mode=mode,
                                        name=id_tag + 'train', on_unused_input='warn')
        if self.options.run_dir and not self.options.no_graphviz:
            self.visualize_graphs({'loss': monitored['loss']},
                                  out_dir=self.options.run_dir)

        test_prediction = get_output(l_out, deterministic=True)
        if self.options.verbosity >= 2:
            print(id_tag_log + 'Compiling prediction function')
        if self.options.verbosity >= 6:
            print('params = %s' % (input_vars,))
        self.predict_fn = theano.function(input_vars, test_prediction, mode=mode,
                                          name=id_tag + 'predict', on_unused_input='ignore')
        if self.options.run_dir and not self.options.no_graphviz:
            self.visualize_graphs({'test_prediction': test_prediction},
                                  out_dir=self.options.run_dir)

    def visualize_graphs(self, monitored, out_dir):
        '''
        Write an SVG rendering of each graph in ``monitored`` to ``out_dir``.

        Files are named ``<id>.<tag>.svg`` (or ``<tag>.svg`` without an id),
        with any '/' in the tag replaced by '.'.
        '''
        id_tag = (self.id + '.') if self.id else ''
        for tag, graph in monitored.items():
            tag = tag.replace('/', '.')
            pydotprint(graph, outfile=os.path.join(out_dir, id_tag + tag + '.svg'),
                       format='svg', var_with_name_simple=True)

    def params(self):
        '''Return the trainable parameters of the network ending at ``l_out``.'''
        return get_all_params(self.l_out, trainable=True)

    def get_train_loss(self, target_vars, params):
        '''
        Build the training objective.

        :returns: ``(monitored, grads, synth_vars)`` where ``monitored`` is an
            OrderedDict starting with ``'loss'`` (the mean of ``self.loss``)
            and, with ``--monitor_grads``, one ``'grad/<param name>'`` entry
            per parameter; ``grads`` are the loss gradients w.r.t. ``params``;
            ``synth_vars`` is always empty here.
        '''
        assert len(target_vars) == 1
        prediction = get_output(self.l_out)
        mean_loss = self.loss(prediction, target_vars[0]).mean()
        monitored = [('loss', mean_loss)]
        grads = T.grad(mean_loss, params)
        if self.options.monitor_grads:
            for p, grad in zip(params, grads):
                monitored.append(('grad/' + p.name, grad))
        return OrderedDict(monitored), grads, []

    def fit(self, Xs, ys, batch_size, num_epochs, summary_writer=None, step=0):
        '''
        Train for ``num_epochs`` shuffled passes over the data.

        :param Xs: sequence of input arrays, one per input variable
        :param ys: sequence of target arrays, one per target variable
        :param batch_size: number of examples per minibatch
        :param num_epochs: number of passes over the data
        :param summary_writer: optional object with ``log_scalar`` and
            ``log_histogram`` methods; when ``None`` nothing is logged
        :param step: global step assigned to the first epoch (epoch ``k`` is
            logged at ``step + k``)
        :returns: OrderedDict mapping each monitored tag to a list with one
            array of per-minibatch values per epoch
        :raises ValueError: if ``Xs`` or ``ys`` is not a sequence
        '''
        if not isinstance(Xs, Sequence):
            raise ValueError('Xs should be a sequence, instead got %s' % (Xs,))
        if not isinstance(ys, Sequence):
            raise ValueError('ys should be a sequence, instead got %s' % (ys,))
        history = OrderedDict((tag, []) for tag in self.monitored_tags)
        id_tag = (self.id + '/') if self.id else ''
        params = self.params()

        progress.start_task('Epoch', num_epochs)
        epoch_start = time.time()
        for epoch in range(num_epochs):
            progress.progress(epoch)
            history_epoch = OrderedDict((tag, []) for tag in self.monitored_tags)
            num_minibatches_approx = len(ys[0]) // batch_size + 1

            progress.start_task('Minibatch', num_minibatches_approx)
            for i, batch in enumerate(self.minibatches(Xs, ys, batch_size, shuffle=True)):
                progress.progress(i)
                if self.options.verbosity >= 8:
                    print('types: %s' % ([type(v) for t in batch for v in t],))
                    print('shapes: %s' % ([v.shape for t in batch for v in t],))
                inputs, targets, synth = batch
                monitored = self.train_fn(*inputs + targets + synth)
                for tag, value in zip(self.monitored_tags, monitored):
                    if self.options.verbosity >= 10:
                        print('%s: %s' % (tag, value))
                    history_epoch[tag].append(value)
            progress.end_task()

            epoch_step = step + epoch
            for tag, values in history_epoch.items():
                values_array = np.array([np.asarray(v) for v in values])
                history[tag].append(values_array)
                if summary_writer is not None:
                    # Per-epoch mean over minibatches.
                    self._log_value(summary_writer, epoch_step, tag,
                                    np.mean(values_array, axis=0))

            if summary_writer is not None and self.options.monitor_params:
                for param in params:
                    self._log_value(summary_writer, epoch_step,
                                    'param/' + param.name, param.get_value())

            epoch_end = time.time()
            if summary_writer is not None:
                examples_per_sec = len(ys[0]) / (epoch_end - epoch_start)
                summary_writer.log_scalar(epoch_step,
                                          id_tag + 'examples_per_sec', examples_per_sec)
            epoch_start = epoch_end
        progress.end_task()
        return history

    @staticmethod
    def _log_value(summary_writer, step, tag, value):
        '''Log a 0-d ``value`` as a scalar and any other array as a histogram.'''
        if len(value.shape) == 0:
            summary_writer.log_scalar(step, tag, value)
        else:
            summary_writer.log_histogram(step, tag, value)

    def predict(self, Xs):
        '''
        Run the compiled prediction function on the input arrays ``Xs``.

        :raises ValueError: if ``Xs`` is not a sequence
        '''
        if not isinstance(Xs, Sequence):
            raise ValueError('Xs should be a sequence, instead got %s' % (Xs,))
        id_tag_log = (self.id + ': ') if self.id else ''
        if self.options.verbosity >= 8:
            print(id_tag_log + 'predict shapes: %s' % [x.shape for x in Xs])
        return self.predict_fn(*Xs)

    def minibatches(self, inputs, targets, batch_size, shuffle=False):
        '''Lifted mostly verbatim from iterate_minibatches in
        https://github.com/Lasagne/Lasagne/blob/master/examples/mnist.py

        Yields ``(input_batches, target_batches, [])`` triples. Examples after
        the last full minibatch are dropped (with ``shuffle``, a different
        random subset each call); if there are fewer than ``batch_size``
        examples, a single smaller minibatch containing all of them is
        yielded.
        '''
        num_examples = len(targets[0])
        assert all(len(X) == num_examples for X in inputs), \
            repr([type(X) for X in inputs] + [type(y) for y in targets])
        assert all(len(y) == num_examples for y in targets), \
            repr([type(X) for X in inputs] + [type(y) for y in targets])
        if shuffle:
            indices = np.arange(num_examples)
            rng.shuffle(indices)
        last_batch = max(0, num_examples - batch_size)
        for start_idx in range(0, last_batch + 1, batch_size):
            if shuffle:
                excerpt = indices[start_idx:start_idx + batch_size]
            else:
                excerpt = slice(start_idx, start_idx + batch_size)
            yield [X[excerpt] for X in inputs], [y[excerpt] for y in targets], []

    def __getstate__(self):
        # ``loss`` and ``l_out`` are replaced by placeholders in the pickled state.
        state = dict(self.__dict__)
        state['loss'] = Unpicklable('loss')
        state['l_out'] = Unpicklable('l_out')
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.get_options()

    def get_options(self):
        '''Snapshot the global config options into ``self.options`` if not already set.'''
        if not hasattr(self, 'options'):
            options = config.options()
            self.options = argparse.Namespace(**options.__dict__)
def output_model_structure(layer, indent=0):
    '''
    Print ``layer`` and, depth-first, every layer feeding into it, one line
    per layer, indented by one space per level below ``indent``.

    A layer's inputs are taken from ``input_layers`` if present, otherwise
    from ``input_layer``.
    '''
    stack = [(layer, indent)]
    while stack:
        node, depth = stack.pop()
        print('%s%s %s' % (' ' * depth, node.name, type(node)))
        if hasattr(node, 'input_layers'):
            parents = list(node.input_layers)
        elif hasattr(node, 'input_layer'):
            parents = [node.input_layer]
        else:
            parents = []
        # Push in reverse so the first input is printed first (preorder).
        stack.extend((parent, depth + 1) for parent in reversed(parents))
class NeuralLearner(Learner):
    '''
    A base class for Lasagne-based learners.

    Several methods below use members this class does not define; they are
    presumably provided by subclasses or the ``Learner`` base (verify):
    ``train_priors``, ``_data_to_arrays``, ``_build_model``, ``validate``,
    ``predict``, ``sample_prior_smooth``, and the attributes ``seq_vec``,
    ``color_vec``, ``l_out``, ``input_layers`` and ``loss``.
    '''

    def __init__(self, id=None):
        '''
        :param id: optional identifier prefixed to log messages and stored in
            the pickled state
        '''
        super(NeuralLearner, self).__init__()
        self.id = id
        self.get_options()

    def train(self, training_instances, validation_instances=None, metrics=None):
        '''
        Fit the priors, build the model, then train it for
        ``options.train_iters`` iterations of ``options.train_epochs`` epochs,
        validating after each iteration.

        When a ``losses.tfevents`` path is configured, training statistics and
        validation results are written there; otherwise no summary writer is
        used (``None`` is passed to the model's ``fit``).
        '''
        id_tag = (self.id + ': ') if self.id else ''
        if self.options.verbosity >= 2:
            print(id_tag + 'Training priors')
        self.train_priors(training_instances, listener_data=self.options.listener)

        self.dataset = training_instances
        xs, ys = self._data_to_arrays(training_instances, init_vectorizer=True)
        self._build_model()

        if self.options.verbosity >= 2:
            print(id_tag + 'Training conditional model')
        summary_path = config.get_file_path('losses.tfevents')
        writer = summary.SummaryWriter(summary_path) if summary_path else None

        progress.start_task('Iteration', self.options.train_iters)
        for iteration in range(self.options.train_iters):
            progress.progress(iteration)
            self.model.fit(xs, ys, batch_size=self.options.batch_size,
                           num_epochs=self.options.train_epochs,
                           summary_writer=writer, step=iteration * self.options.train_epochs)
            validation_results = self.validate(validation_instances, metrics, iteration=iteration)
            if writer is not None:
                step = (iteration + 1) * self.options.train_epochs
                self.on_iter_end(step, writer)
                for key, value in validation_results.items():
                    # Drop the first dotted component of the key and use '/'
                    # as the separator: 'x.a.b' -> 'val/a/b'.
                    tag = 'val/' + key.split('.', 1)[1].replace('.', '/')
                    writer.log_scalar(step, tag, value)
                writer.flush()
        progress.end_task()

    def on_iter_end(self, step, writer):
        '''Hook run after each training iteration when a summary writer exists; no-op here.'''
        pass

    def params(self):
        '''Return the parameters of the underlying model.'''
        return self.model.params()

    @property
    def num_params(self):
        '''Total number of scalar values across all arrays returned by ``params()``.'''
        all_params = self.params()
        return sum(np.prod(p.get_value().shape) for p in all_params)

    def log_prior_emp(self, input_vars):
        return self.prior_emp.apply(input_vars)

    def log_prior_smooth(self, input_vars):
        return self.prior_smooth.apply(input_vars)

    def sample(self, inputs):
        '''Return outputs drawn via ``predict(..., random=True)`` for ``inputs``.'''
        return self.predict(inputs, random=True, verbosity=-6)

    def sample_prior_emp(self, num_samples):
        '''
        Return ``num_samples`` training instances chosen uniformly at random
        (with replacement), each passed through ``stripped()``.
        '''
        indices = rng.randint(len(self.dataset), size=num_samples)
        return [self.dataset[i].stripped() for i in indices]

    def sample_joint_emp(self, num_samples=1):
        '''Sample inputs with ``sample_prior_emp`` and attach sampled outputs.'''
        return self._attach_sampled_outputs(self.sample_prior_emp(num_samples))

    def sample_joint_smooth(self, num_samples=1):
        '''
        Sample inputs with ``sample_prior_smooth`` and attach sampled outputs.
        ``sample_prior_smooth`` is not defined in this class.
        '''
        return self._attach_sampled_outputs(self.sample_prior_smooth(num_samples))

    def _attach_sampled_outputs(self, input_insts):
        '''Set ``output`` on each instance to a value from ``sample`` and return the instances.'''
        outputs = self.sample(input_insts)
        for inst, out in zip(input_insts, outputs):
            inst.output = out
        return input_insts

    def log_joint_smooth(self, input_vars, target_var):
        '''Symbolic smoothed log-prior of the inputs minus the model loss.'''
        return (self.log_prior_smooth(input_vars) -
                self.loss_out(input_vars, target_var))

    def log_joint_emp(self, input_vars, target_var):
        '''Symbolic empirical log-prior of the inputs minus the model loss.'''
        return (self.log_prior_emp(input_vars) -
                self.loss_out(input_vars, target_var))

    def loss_out(self, input_vars=None, target_var=None):
        '''
        Symbolic loss of ``l_out`` with ``input_layers`` bound to
        ``input_vars``, against ``target_var``.

        Defaults come from ``self.model.input_vars`` / ``self.model.target_var``.
        '''
        if input_vars is None:
            input_vars = self.model.input_vars
        if target_var is None:
            # NOTE(review): SimpleLasagneModel does not set ``target_var``;
            # this default presumably relies on a different model class.
            target_var = self.model.target_var
        pred = get_output(self.l_out, dict(zip(self.input_layers, input_vars)))
        return self.loss(pred, target_var)

    def __getstate__(self):
        '''
        Pickle as ``(seq_vec, color_vec, param_values, id)``, extended with
        ``(prior_emp, prior_smooth)`` when both priors exist.

        :raises RuntimeError: if the model has not been built yet
        '''
        if not hasattr(self, 'model'):
            raise RuntimeError("trying to pickle a model that hasn't been built yet")
        params = self.params()
        # TODO: remove references to the vectorizers and priors from this superclass
        state = (self.seq_vec, self.color_vec, [p.get_value() for p in params], self.id)
        if hasattr(self, 'prior_emp') and hasattr(self, 'prior_smooth'):
            return state + (self.prior_emp, self.prior_smooth)
        else:
            return state

    def __setstate__(self, state):
        self.unpickle(state)

    def unpickle(self, state, model_class=SimpleLasagneModel):
        '''
        Restore from any supported pickled state.

        Accepted formats:
          * a dict with a truthy ``'quickpickle'`` entry: copied into ``__dict__``;
          * ``(seq_vec, color_vec, param_values)``: ``id`` is ``None`` and the
            priors are retrained on an empty dataset;
          * ``(seq_vec, color_vec, param_values, id)``: priors retrained on an
            empty dataset;
          * ``(seq_vec, color_vec, param_values, id, prior_emp, prior_smooth)``.
        For the tuple formats the model is rebuilt with ``model_class`` and
        the parameter values are assigned in order.
        '''
        if isinstance(state, dict) and state.get('quickpickle'):
            self.__dict__.update(state)
            self.get_options()
            return

        self.get_options()
        # TODO: remove references to the vectorizers from this superclass
        if len(state) == 3:
            self.seq_vec, self.color_vec, params_state = state
            self.id = None
            self.train_priors([])
        elif len(state) == 4:
            self.seq_vec, self.color_vec, params_state, self.id = state
            self.train_priors([])
        else:
            (self.seq_vec, self.color_vec,
             params_state, self.id,
             self.prior_emp, self.prior_smooth) = state

        self._build_model(model_class)
        params = self.params()
        assert len(params) == len(params_state), '%d != %d' % (len(params), len(params_state))
        for p, value in zip(params, params_state):
            p.set_value(value)

    def get_options(self):
        '''Snapshot the global config options into ``self.options`` if not already set.'''
        if not hasattr(self, 'options'):
            options = config.options()
            self.options = argparse.Namespace(**options.__dict__)