train.py
from __future__ import absolute_import
from __future__ import print_function
# The __future__ imports keep this script runnable under both Python 2 and Python 3.
import numpy as np
import os
import nn_utils.network_utils as network_utils
import config.nn_config as nn_config
config = nn_config.get_neural_net_configuration()
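# The config is assumed to provide at least 'model_file' (training-data path
# prefix), 'model_basename' (weight-checkpoint prefix), and
# 'hidden_dimension_size' (LSTM hidden width), per the lookups below.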
# 'model_file' is the path prefix of the preprocessed training tensors;
# the '_x.npy' / '_y.npy' suffixes are appended when they are loaded below.
input_file = config['model_file']
cur_iter = 0
model_basename = config['model_basename']
# Weight checkpoints are numbered by iteration, so iteration 0 uses model_basename + '0'
model_filename = model_basename + str(cur_iter)
# Load up the training data
print('Loading training data')
# X_train is a tensor of size (num_train_examples, num_timesteps, num_frequency_dims)
# y_train is a tensor of size (num_train_examples, num_timesteps, num_frequency_dims)
X_train = np.load(input_file + '_x.npy')
y_train = np.load(input_file + '_y.npy')
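# Optional sanity check, a minimal sketch assuming the shape comments above
# hold: inputs and targets must line up example-for-example.
assert X_train.shape == y_train.shape, 'X_train and y_train shapes must match'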
print('Finished loading training data')
# Figure out how many frequencies we have in the data
freq_space_dims = X_train.shape[2]  # 88200
print('freq_space_dims =', freq_space_dims)
hidden_dims = config['hidden_dimension_size']
print('Using Mean Absolute Error')
# Create an LSTM network
model = network_utils.create_lstm_network(num_frequency_dimensions=freq_space_dims,
                                          num_hidden_dimensions=hidden_dims)  # hidden_dims = 1024
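# Rough sketch of the architecture create_lstm_network is assumed to build
# (the actual definition lives in nn_utils/network_utils.py):
#   TimeDistributed(Dense(num_hidden_dimensions))
#     -> LSTM(num_hidden_dimensions, return_sequences=True)
#     -> TimeDistributed(Dense(num_frequency_dimensions))
# compiled with a mean-absolute-error loss, per the print above.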
# Load existing weights if available
if os.path.isfile(model_filename):
    model.load_weights(model_filename)
num_iters = 1        # Number of training iterations to run
epochs_per_iter = 1  # Number of epochs per iteration
batch_size = 1       # Number of training examples pushed to the GPU per batch.
                     # Larger batch sizes require more memory, but training will be faster.
print('Starting training!')
while cur_iter < num_iters:
    print('Iteration: ' + str(cur_iter))
    # We set validation_split to 0, since cross-validation would run on
    # different subsets of the data if we reload our model between runs.
    # The proper way to handle this is to manually split your data into two
    # sets and run validation after you've trained the model for some number
    # of epochs (see the sketch after this loop).
    # Iterate over the training data in batches.
    # Note: nb_epoch is the Keras 1 name for this argument; Keras 2 renamed it to epochs.
    history = model.fit(X_train, y_train, batch_size=batch_size, nb_epoch=epochs_per_iter,
                        verbose=1, validation_split=0.0)
    cur_iter += epochs_per_iter
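# A minimal sketch of the manual train/validation split suggested above; the
# 80/20 ratio and the *_tr / *_val names are illustrative, not part of this repo:
#   split = int(0.8 * len(X_train))
#   X_tr, X_val = X_train[:split], X_train[split:]
#   y_tr, y_val = y_train[:split], y_train[split:]
#   model.fit(X_tr, y_tr, batch_size=batch_size, nb_epoch=epochs_per_iter, verbose=1)
#   val_loss = model.evaluate(X_val, y_val, batch_size=batch_size, verbose=0)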
print('Training complete!')
model.save_weights(model_basename + str(cur_iter))
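# The weights are saved under model_basename plus the final value of cur_iter;
# to resume training in a later run, set cur_iter to that value before the
# load_weights check above.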