-
Notifications
You must be signed in to change notification settings - Fork 87
/
ssq4all.py
120 lines (110 loc) · 5.82 KB
/
ssq4all.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# -*- coding: utf-8 -*-
# file: ssq4all.py
# author: JinTian
# time: 11/03/2017 9:53 AM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
import os
import numpy as np
import tensorflow as tf
from poems.model import rnn_model
from poems.resnet import *
from poems.poems import process_poems, generate_batch
from ssq_data import *
# for Windows10:OSError: raw write() returned invalid length 96 (should have been between 0 and 48)
# import win_unicode_console
# win_unicode_console.enable()
# Command-line flags for this script; all values are accessed through FLAGS.
# Note: run_training() overwrites FLAGS.batch_size with the dataset length, and
# 'file_path' is only referenced by the commented-out poems pipeline.
_flag_defs = tf.app.flags
_flag_defs.DEFINE_integer('batch_size', 2214, 'batch size.')
_flag_defs.DEFINE_float('learning_rate', 0.0001, 'learning rate.')
_flag_defs.DEFINE_string('model_dir', os.path.abspath('./model4all'), 'model save path.')
_flag_defs.DEFINE_string('file_path', os.path.abspath('./data/poems.txt'), 'file name of poems.')
_flag_defs.DEFINE_string('model_prefix', 'poems', 'model save prefix.')
_flag_defs.DEFINE_integer('epochs', 500000, 'train how many epochs.')
FLAGS = _flag_defs.FLAGS
# Drop the temporary alias so the module namespace matches the plain form.
del _flag_defs
def run_training():
    """Build the ResNet + LSTM graph, train it, and checkpoint to FLAGS.model_dir.

    Training pairs come from two ``get_exl_data`` calls:
      * inputs  = every row except the last (``use_resnet=True`` layout),
      * targets = every row except the first (``use_resnet=False`` layout),
    so input row i is paired with target row i + 1 -- presumably both calls
    return rows in the same order (``random_order=False``); TODO confirm in
    ssq_data.

    ``FLAGS.batch_size`` is overwritten with the number of pairs, so the whole
    dataset is fed as one batch and each epoch is a single training step.

    If a checkpoint exists in ``FLAGS.model_dir``, it is restored and training
    resumes from its global step. A checkpoint is saved every 50000 epochs and
    once more when training stops (normal end, Ctrl-C, or an error).

    Raises:
        ValueError: if fewer than two data rows are available, or if the input
            and target sequences end up with different lengths.
    """
    if not os.path.exists(FLAGS.model_dir):
        os.makedirs(FLAGS.model_dir)

    ssqdata = get_exl_data(random_order=False, use_resnet=True)
    batches_inputs = ssqdata[0:(len(ssqdata) - 1)]
    ssqdata = get_exl_data(random_order=False, use_resnet=False)
    batches_outputs = ssqdata[1:len(ssqdata)]
    del ssqdata

    n_samples = len(batches_inputs)
    if n_samples == 0:
        # Would otherwise surface later as a ZeroDivisionError / zero-size placeholder.
        raise ValueError('need at least two rows of data to build input/target pairs')
    if len(batches_outputs) != n_samples:
        raise ValueError('input/target length mismatch: %d vs %d'
                         % (n_samples, len(batches_outputs)))
    # The whole history is one batch; this overrides the --batch_size flag.
    FLAGS.batch_size = n_samples

    # Each input sample is fed as a 1x7x1 tensor.
    input_data = tf.placeholder(tf.float32, [FLAGS.batch_size, 1, 7, 1])
    logits = inference(input_data, 2, reuse=False, output_num=128)
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    # NOTE(review): vocab_size 33 + 16 presumably covers the 33 red-ball and
    # 16 blue-ball numbers -- confirm against ssq_data's encoding.
    end_points = rnn_model(model='lstm', input_data=logits, output_data=output_targets,
                           vocab_size=33 + 16, output_num=7,
                           rnn_size=128, num_layers=7, batch_size=FLAGS.batch_size,
                           learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.model_prefix)

    with tf.Session() as sess:
        sess.run(init_op)
        start_epoch = 0
        checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print("## restore from the checkpoint {0}".format(checkpoint))
            # Checkpoints are saved with global_step=epoch, so the trailing
            # '-<step>' of the checkpoint name is the epoch to resume from.
            start_epoch += int(checkpoint.split('-')[-1])
        print('## start training...')

        batch_size = FLAGS.batch_size
        # Integer ceil-division: math.ceil() returns a float on Python 2 (which
        # breaks range()), and `math` was never imported explicitly here.
        n_chunk = (n_samples + batch_size - 1) // batch_size
        # Bind `epoch` up front so the `finally` save never hits an unbound name
        # (empty epoch range, or a failure before the first epoch starts).
        epoch = start_epoch
        try:
            for epoch in range(start_epoch, FLAGS.epochs):
                for batch in range(n_chunk):
                    start = batch * batch_size
                    end = start + batch_size
                    if end >= n_samples:
                        # Final chunk: take the last batch_size rows so the feed
                        # always matches the fixed placeholder batch dimension.
                        start, end = n_samples - batch_size, n_samples
                    inputdata = batches_inputs[start:end]
                    outputdata = batches_outputs[start:end]
                    loss, _, _ = sess.run([
                        end_points['total_loss'],
                        end_points['last_state'],
                        end_points['train_op']
                    ], feed_dict={input_data: inputdata, output_targets: outputdata})
                    print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
                if epoch % 50000 == 0:
                    saver.save(sess, checkpoint_path, global_step=epoch)
        except KeyboardInterrupt:
            print('## Interrupt manually, try saving checkpoint for now...')
        finally:
            saver.save(sess, checkpoint_path, global_step=epoch)
            print('## Last epoch were saved, next time will start from epoch {}.'.format(epoch))
def main(_):
    """Script entry point called by ``tf.app.run()``; the argv argument is ignored."""
    return run_training()
if __name__ == '__main__':
    # Hand control to TensorFlow's app runner, naming the entry point explicitly
    # (identical to its default of looking up this module's `main`).
    tf.app.run(main=main)