
Commit dc5064d

Author: felixfzhao
Commit message: first commit
Parent: 858e3a3


69 files changed: +1115487 −1 lines

.idea/Tower_copy.iml (+11)
.idea/inspectionProfiles/profiles_settings.xml (+7)
.idea/misc.xml (+4)
.idea/modules.xml (+8)
.idea/vcs.xml (+6)
.idea/workspace.xml (+386)

LOTN.py (+263)
#!/usr/bin/env python
# encoding: utf-8

# NOTE: tf, FLAGS, and loss_func (used in main below) are never imported
# explicitly; they are expected to arrive via the wildcard import from
# transfer_module.config.
from transfer_module.nn_layer import softmax_layer, bi_dynamic_rnn
from transfer_module.config import *
from transfer_module.utils import load_w2v, score_BIO, batch_iter, load_inputs_10
import datetime
import numpy as np
import os


def LOTN(inputs, inputs_s_1, position, sen_len, target, sen_len_tr, attention1, keep_prob1, _id='all'):
    cell = tf.contrib.rnn.LSTMCell
    inputs = tf.nn.dropout(inputs, keep_prob=keep_prob1)
    inputs_s_1 = tf.nn.dropout(inputs_s_1, keep_prob=keep_prob1)

    # Transferred encoder: its variables ("rnn/sen12") can be restored from a
    # pre-trained checkpoint in main().
    with tf.variable_scope("rnn"):
        hiddens_t = bi_dynamic_rnn(cell, inputs_s_1, FLAGS.n_hidden, sen_len, 'sen12')

    # Task-specific encoder, trained from scratch.
    with tf.variable_scope("rnn1"):
        hiddens_s = bi_dynamic_rnn(cell, inputs, FLAGS.n_hidden, sen_len, 'sen13')

    hidden_total = tf.concat([hiddens_t, hiddens_s], 2)

    # Main tag predictions over the concatenated encoders, plus auxiliary
    # attention predictions over the task-specific encoder alone.
    outputs = softmax_layer(hidden_total, 4 * FLAGS.n_hidden, FLAGS.max_sentence_len, keep_prob1, FLAGS.l2_reg,
                            FLAGS.n_class, 'sen22')
    outputs_att = softmax_layer(hiddens_s, 2 * FLAGS.n_hidden, FLAGS.max_sentence_len, keep_prob1, FLAGS.l2_reg,
                                FLAGS.plority, 'sen33')

    return outputs, outputs_att

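# Shape bookkeeping (inferred from the calls above, not stated in the commit):
# bi_dynamic_rnn concatenates forward and backward states, so hiddens_t and
# hiddens_s are each [batch, max_sentence_len, 2 * n_hidden]; their
# concatenation hidden_total is [batch, max_sentence_len, 4 * n_hidden],
# matching the 4 * FLAGS.n_hidden and 2 * FLAGS.n_hidden widths passed to
# softmax_layer.
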
def preprocess(word_id_mapping):
    tr_x, tr_sen_len, tr_target_word, tr_tar_len, tr_y, tr_position, tr_attention = load_inputs_10(
        FLAGS.train_file_path,
        word_id_mapping,
        FLAGS.max_sentence_len,
        FLAGS.max_target_len
    )

    # Fixed seed so the train/dev split is reproducible across runs.
    np.random.seed(10)
    shuffle_indices = np.random.permutation(np.arange(len(tr_x)))

    x_shuffled = tr_x[shuffle_indices]
    tr_sen_len_shuffled = tr_sen_len[shuffle_indices]
    tr_target_word_shuffled = tr_target_word[shuffle_indices]
    tr_tar_len_shuffled = tr_tar_len[shuffle_indices]
    tr_y_shuffled = tr_y[shuffle_indices]
    tr_position_shuffled = tr_position[shuffle_indices]
    tr_attention_shuffled = tr_attention[shuffle_indices]

    # Hold out the last dev_sample_percentage of the shuffled data as a dev set.
    dev_sample_index = -1 * int(FLAGS.dev_sample_percentage * float(len(tr_x)))

    tr_x_train, tr_x_dev = x_shuffled[:dev_sample_index], x_shuffled[dev_sample_index:]
    tr_sen_len_train, tr_sen_len_dev = tr_sen_len_shuffled[:dev_sample_index], tr_sen_len_shuffled[dev_sample_index:]
    tr_target_word_train, tr_target_word_dev = tr_target_word_shuffled[:dev_sample_index], tr_target_word_shuffled[
        dev_sample_index:]
    tr_tar_len_train, tr_tar_len_dev = tr_tar_len_shuffled[:dev_sample_index], tr_tar_len_shuffled[dev_sample_index:]
    tr_y_train, tr_y_dev = tr_y_shuffled[:dev_sample_index], tr_y_shuffled[dev_sample_index:]
    tr_position_train, tr_position_dev = tr_position_shuffled[:dev_sample_index], tr_position_shuffled[
        dev_sample_index:]
    tr_attention_train, tr_attention_dev = tr_attention_shuffled[:dev_sample_index], tr_attention_shuffled[
        dev_sample_index:]

    print("Train/Dev split: {:d}/{:d}".format(len(tr_x_train), len(tr_x_dev)))
    return tr_x_train, tr_sen_len_train, tr_target_word_train, tr_tar_len_train, \
        tr_y_train, tr_position_train, tr_attention_train, tr_x_dev, tr_sen_len_dev, tr_target_word_dev, \
        tr_tar_len_dev, tr_y_dev, tr_position_dev, tr_attention_dev

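# Worked example of the split above (illustrative numbers, not from this
# commit): with len(tr_x) == 1000 and FLAGS.dev_sample_percentage == 0.1,
# dev_sample_index == -100, so x_shuffled[:-100] is the 900-example training
# slice and x_shuffled[-100:] the 100-example dev slice.
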
def main(_):
    word_id_mapping, w2v = load_w2v(FLAGS.embedding_file_path, FLAGS.embedding_dim)
    word_embedding = tf.constant(w2v, name='word_embedding')

    tr_x_train, tr_sen_len_train, tr_target_word_train, tr_tar_len_train, \
        tr_y_train, tr_position_train, tr_attention_train, tr_x_dev, tr_sen_len_dev, tr_target_word_dev, \
        tr_tar_len_dev, tr_y_dev, tr_position_dev, tr_attention_dev = preprocess(word_id_mapping)

    keep_prob1 = tf.placeholder(tf.float32, name='input_keep_prob1')
    with tf.name_scope('inputs'):
        x = tf.placeholder(tf.int32, [None, FLAGS.max_sentence_len], name='input_x')
        y = tf.placeholder(tf.float32, [None, FLAGS.max_sentence_len, FLAGS.n_class], name='input_y')
        sen_len = tf.placeholder(tf.int32, [None], name='input_sen_len')
        target_words = tf.placeholder(tf.int32, [None, FLAGS.max_target_len], name='input_target')
        tar_len = tf.placeholder(tf.int32, [None], name='input_tar_len')
        position = tf.placeholder(tf.int32, [None, FLAGS.max_sentence_len], name='position')
        attention1 = tf.placeholder(tf.float32, [None, FLAGS.max_sentence_len, FLAGS.plority],
                                    name='attention_parameter_1')

    inputs_s = tf.nn.embedding_lookup(word_embedding, x)
    inputs_s_1 = tf.nn.embedding_lookup(word_embedding, x)

    position_embeddings = tf.get_variable(
        name='position_embedding',
        shape=[FLAGS.max_sentence_len, FLAGS.position_embedding_dim],
        initializer=tf.random_uniform_initializer(-FLAGS.random_base, FLAGS.random_base),
        regularizer=tf.contrib.layers.l2_regularizer(FLAGS.l2_reg)
    )

    input_position = tf.nn.embedding_lookup(position_embeddings, position)

    # target_1 = tf.reduce_mean(tf.nn.embedding_lookup(word_embedding, target_words), 1, keep_dims=True)
    # batch_size = tf.shape(inputs_s)[0]
    # target_2 = tf.zeros([batch_size, FLAGS.max_sentence_len, FLAGS.embedding_dim]) + target_1

    # Only the task-specific branch sees position embeddings; the transferred
    # branch (inputs_s_1) receives the plain word embeddings.
    inputs_s = tf.concat([inputs_s, input_position], 2)

    # inputs_s = tf.concat([inputs_s, target_2], 2)

    target = tf.nn.embedding_lookup(word_embedding, target_words)
    prob, prob1 = LOTN(inputs_s, inputs_s_1, position, sen_len, target, tar_len, attention1, keep_prob1, FLAGS.t1)

    # Main BIO-tagging loss plus the auxiliary attention-supervision loss,
    # weighted by FLAGS.Auxiliary_loss.
    loss1 = loss_func(y, prob)
    loss2 = loss_func(attention1, prob1)

    loss = loss1 + FLAGS.Auxiliary_loss * loss2

    # acc_num, acc_prob = acc_func(y, prob)

    global_step = tf.Variable(0, name='tr_global_step', trainable=False)
    optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate).minimize(loss, global_step=global_step)

    true_y = tf.argmax(y, 2, name='true_y_1')
    pred_y = tf.argmax(prob, 2, name='pred_y_1')

    true_attention = tf.argmax(attention1, 2, name='true_attention_1')
    pred_attention = tf.argmax(prob1, 2, name='pred_attention_1')

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        init = tf.global_variables_initializer()
        sess.run(init)

        # Optionally restore the transferred encoder ("rnn/sen12") from a
        # pre-trained checkpoint, skipping Adam optimizer slot variables.
        # NB: "sentence_tranfer" (sic) is the value the script checks for.
        if FLAGS.pre_trained == "sentence_tranfer":
            pre_trained_variables = [v for v in tf.global_variables() if
                                     v.name.startswith("rnn/sen12") and "Adam" not in v.name]
            print(pre_trained_variables)
            saver = tf.train.Saver(pre_trained_variables)
            ckpt = tf.train.get_checkpoint_state(FLAGS.pre_trained_path)
            # ckpt = tf.train.get_checkpoint_state('data/yelp/checkpoint8/')
            saver.restore(sess, ckpt.model_checkpoint_path)
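            # Hypothetical name for illustration: a weight such as
            # "rnn/sen12/bidirectional_rnn/fw/lstm_cell/kernel:0" would pass
            # the startswith filter above, while optimizer slots like
            # ".../kernel/Adam:0" are excluded (exact names depend on
            # bi_dynamic_rnn's internals).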
        def train_step(i, x_f, sen_len_f, target, tl, yi, x_position, x_attention, kp1):
            feed_dict = {
                x: x_f,
                y: yi,
                sen_len: sen_len_f,
                target_words: target,
                tar_len: tl,
                position: x_position,
                attention1: x_attention,
                keep_prob1: kp1
            }
            step, _, losses = sess.run([global_step, optimizer, loss], feed_dict)
            time_str = datetime.datetime.now().isoformat()
            print("{}: Iter {}, step {}, loss {:g}".format(time_str, i, step, losses))

        def dev_step(te_x_f, te_sen_len_f, te_target, te_tl, te_yi, te_x_position, te_x_attention):
            feed_dict = {
                x: te_x_f,
                y: te_yi,
                sen_len: te_sen_len_f,
                target_words: te_target,
                tar_len: te_tl,
                position: te_x_position,
                attention1: te_x_attention,
                keep_prob1: 1.0  # no dropout at evaluation time
            }

            tf_true, tf_pred, tf_true_attention, tf_pred_attention, _prob, _loss = sess.run(
                [true_y, pred_y, true_attention, pred_attention, prob, loss], feed_dict)
            cost = 0
            pre_label, att_pre_label, true_label, att_true_label = [], [], [], []
            # Truncate each sequence to its true length before scoring. The
            # original zipped against the full tr_sen_len_dev array, which
            # misaligns lengths for any dev batch after the first; zipping
            # against this batch's te_sen_len_f fixes that.
            for logit, position1, att_logit, att_position1, length in zip(tf_pred, tf_true, tf_pred_attention,
                                                                          tf_true_attention, te_sen_len_f):
                logit = logit[:length]
                tr_position = position1[:length]

                att_logit = att_logit[:length]
                tr_att_position = att_position1[:length]

                cost += _loss * length
                pre_label.append(logit)
                true_label.append(tr_position)

                att_pre_label.append(att_logit)
                att_true_label.append(tr_att_position)

            return pre_label, att_pre_label, true_label, att_true_label, cost
        checkpoint_dir = os.path.abspath(FLAGS.saver_checkpoint)
        checkpoint_prefix = os.path.join(checkpoint_dir, "model")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)

        # Track the best dev F1 seen so far and the predictions that produced it.
        max_f1 = 0
        max_recall = 0
        max_precision = 0
        last_improvement = 0
        require_improvement_iterations = 5
        max_label, max_att_label = None, None
        for i in range(FLAGS.n_iter):
            batches_train = batch_iter(
                list(zip(tr_x_train, tr_sen_len_train, tr_target_word_train, tr_tar_len_train, tr_y_train,
                         tr_position_train, tr_attention_train)), FLAGS.batch_size, 1, True)
            for batch in batches_train:
                x_batch, sen_len_batch, target_batch, tar_len_batch, y_batch, position_batch, attention_batch = zip(
                    *batch)
                train_step(i, x_batch, sen_len_batch, target_batch, tar_len_batch, y_batch, position_batch,
                           attention_batch,
                           FLAGS.keep_prob1)

            # Evaluate on the dev set in fixed batches of 500, without shuffling.
            batches_test = batch_iter(
                list(zip(tr_x_dev, tr_sen_len_dev, tr_target_word_dev, tr_tar_len_dev, tr_y_dev, tr_position_dev,
                         tr_attention_dev)), 500,
                1, False)
            label_pp, att_label_pp, label_tt, att_label_tt = [], [], [], []
            cost1 = 0
            for batch_ in batches_test:
                te_x_batch, te_sen_len_batch, te_target_batch, te_tar_len_batch, te_y_batch, te_position_batch, te_attention_batch = zip(
                    *batch_)
                label_p, att_label_p, label_t, att_label_t, _loss = dev_step(te_x_batch, te_sen_len_batch,
                                                                             te_target_batch, te_tar_len_batch,
                                                                             te_y_batch, te_position_batch,
                                                                             te_attention_batch)
                label_pp += label_p
                label_tt += label_t

                att_label_pp += att_label_p
                att_label_tt += att_label_t

                cost1 += _loss
            print("\nEvaluation:")

            precision, recall, f1 = score_BIO(label_pp, label_tt)
            current_step = tf.train.global_step(sess, global_step)
            print("Iter {}: step {}, loss {}, precision {:g}, recall {:g}, f1 {:g}".format(
                i, current_step, cost1, precision, recall, f1))

            # Keep the checkpoint and predictions from the best-F1 iteration.
            if f1 > max_f1:
                max_f1 = f1
                max_precision = precision
                max_recall = recall
                max_label = label_pp
                max_att_label = att_label_pp
                last_improvement = i
                path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                print("Saved model checkpoint to {}\n".format(path))
            print("topf1 {:g}, precision {:g}, recall {:g}".format(max_f1, max_precision, max_recall))
            print("\n")
            # if i - last_improvement > require_improvement_iterations:
            #     print('No improvement found in a while, stop running')
            #     break

        # Write the best iteration's predicted label sequences, one sentence
        # per line. A context manager replaces the original bare open() so
        # the file is flushed and closed.
        with open(FLAGS.prob_file, 'w') as fp:
            for ws in max_label:
                fp.write(' '.join([str(w) for w in ws]) + '\n')


if __name__ == '__main__':
    tf.app.run()
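
For context, every hyperparameter above comes from FLAGS, which the wildcard import is expected to pull out of transfer_module/config.py (defined elsewhere in the repository, not shown here). A minimal sketch of the flag surface LOTN.py touches, with placeholder defaults rather than the repo's real ones:

# Hypothetical sketch: flag names are taken from their usages in LOTN.py,
# but every default and help string below is a placeholder, not the repo's.
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('n_hidden', 100, 'LSTM hidden units per direction')
tf.app.flags.DEFINE_integer('max_sentence_len', 80, 'sentence padding length')
tf.app.flags.DEFINE_integer('n_class', 3, 'BIO tag classes')
tf.app.flags.DEFINE_integer('plority', 2, 'attention label classes (sic)')
tf.app.flags.DEFINE_float('l2_reg', 0.001, 'L2 regularization weight')
tf.app.flags.DEFINE_float('learning_rate', 0.001, 'Adam learning rate')
tf.app.flags.DEFINE_float('Auxiliary_loss', 0.1, 'weight on the attention loss')
tf.app.flags.DEFINE_integer('n_iter', 20, 'training iterations')
tf.app.flags.DEFINE_integer('batch_size', 32, 'minibatch size')
tf.app.flags.DEFINE_string('pre_trained', 'sentence_tranfer', 'enables encoder transfer when set to this value')
# ...plus the path/embedding flags used above: train_file_path,
# embedding_file_path, embedding_dim, position_embedding_dim, random_base,
# max_target_len, dev_sample_percentage, keep_prob1, t1, pre_trained_path,
# saver_checkpoint, prob_file.

Since tf.app.run() hands the command line to tf.flags, any of these can be overridden at launch, e.g. python LOTN.py --n_iter 30 --batch_size 64.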
