#!/usr/bin/env python
# encoding: utf-8

from transfer_module.nn_layer import softmax_layer, bi_dynamic_rnn
# The star import is expected to supply FLAGS and the `loss_func` helper used below.
from transfer_module.config import *
from transfer_module.utils import load_w2v, score_BIO, batch_iter, load_inputs_10
import datetime
import numpy as np
import os
import tensorflow as tf


def LOTN(inputs, inputs_s_1, position, sen_len, target, sen_len_tr, attention1, keep_prob1, _id='all'):
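    """Two-encoder tagging network.

    `inputs_s_1` runs through the "rnn"/'sen12' BiLSTM, whose weights can be
    restored from the pre-trained sentence-level model below, while `inputs`
    runs through a freshly initialized "rnn1"/'sen13' BiLSTM. The concatenated
    hidden states feed the main sequence tagger (`outputs`, FLAGS.n_class tags),
    and the second encoder alone feeds the auxiliary attention tagger
    (`outputs_att`, FLAGS.plority tags). `position`, `target`, `sen_len_tr`,
    `attention1` and `_id` are kept for interface compatibility and are unused here.
    """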
    cell = tf.contrib.rnn.LSTMCell
    inputs = tf.nn.dropout(inputs, keep_prob=keep_prob1)
    inputs_s_1 = tf.nn.dropout(inputs_s_1, keep_prob=keep_prob1)

    with tf.variable_scope("rnn"):
        hiddens_t = bi_dynamic_rnn(cell, inputs_s_1, FLAGS.n_hidden, sen_len, 'sen12')

    with tf.variable_scope("rnn1"):
        hiddens_s = bi_dynamic_rnn(cell, inputs, FLAGS.n_hidden, sen_len, 'sen13')

    hidden_total = tf.concat([hiddens_t, hiddens_s], 2)

    outputs = softmax_layer(hidden_total, 4 * FLAGS.n_hidden, FLAGS.max_sentence_len, keep_prob1, FLAGS.l2_reg,
                            FLAGS.n_class, 'sen22')
    outputs_att = softmax_layer(hiddens_s, 2 * FLAGS.n_hidden, FLAGS.max_sentence_len, keep_prob1, FLAGS.l2_reg,
                                FLAGS.plority, 'sen33')

    return outputs, outputs_att


def preprocess(word_id_mapping):
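    """Load the training file and split it into train/dev sets.

    Examples are shuffled with a fixed seed for reproducibility; the last
    `dev_sample_percentage` fraction of the shuffled data becomes the dev set.
    """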
    tr_x, tr_sen_len, tr_target_word, tr_tar_len, tr_y, tr_position, tr_attention = load_inputs_10(
        FLAGS.train_file_path,
        word_id_mapping,
        FLAGS.max_sentence_len,
        FLAGS.max_target_len
    )

    np.random.seed(10)
    shuffle_indices = np.random.permutation(np.arange(len(tr_x)))

    x_shuffled = tr_x[shuffle_indices]
    tr_sen_len_shuffled = tr_sen_len[shuffle_indices]
    tr_target_word_shuffled = tr_target_word[shuffle_indices]
    tr_tar_len_shuffled = tr_tar_len[shuffle_indices]
    tr_y_shuffled = tr_y[shuffle_indices]
    tr_position_shuffled = tr_position[shuffle_indices]
    tr_attention_shuffled = tr_attention[shuffle_indices]

    dev_sample_index = -1 * int(FLAGS.dev_sample_percentage * float(len(tr_x)))

    tr_x_train, tr_x_dev = x_shuffled[:dev_sample_index], x_shuffled[dev_sample_index:]
    tr_sen_len_train, tr_sen_len_dev = tr_sen_len_shuffled[:dev_sample_index], tr_sen_len_shuffled[dev_sample_index:]
    tr_target_word_train, tr_target_word_dev = tr_target_word_shuffled[:dev_sample_index], tr_target_word_shuffled[dev_sample_index:]
    tr_tar_len_train, tr_tar_len_dev = tr_tar_len_shuffled[:dev_sample_index], tr_tar_len_shuffled[dev_sample_index:]
    tr_y_train, tr_y_dev = tr_y_shuffled[:dev_sample_index], tr_y_shuffled[dev_sample_index:]
    tr_position_train, tr_position_dev = tr_position_shuffled[:dev_sample_index], tr_position_shuffled[dev_sample_index:]
    tr_attention_train, tr_attention_dev = tr_attention_shuffled[:dev_sample_index], tr_attention_shuffled[dev_sample_index:]

    print("Train/Dev split: {:d}/{:d}".format(len(tr_x_train), len(tr_x_dev)))
    return tr_x_train, tr_sen_len_train, tr_target_word_train, tr_tar_len_train, \
        tr_y_train, tr_position_train, tr_attention_train, tr_x_dev, tr_sen_len_dev, tr_target_word_dev, \
        tr_tar_len_dev, tr_y_dev, tr_position_dev, tr_attention_dev


def main(_):
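    """Build the LOTN graph, optionally restore the pre-trained encoder, then train and evaluate."""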
    word_id_mapping, w2v = load_w2v(FLAGS.embedding_file_path, FLAGS.embedding_dim)
    word_embedding = tf.constant(w2v, name='word_embedding')

    tr_x_train, tr_sen_len_train, tr_target_word_train, tr_tar_len_train, \
        tr_y_train, tr_position_train, tr_attention_train, tr_x_dev, tr_sen_len_dev, tr_target_word_dev, \
        tr_tar_len_dev, tr_y_dev, tr_position_dev, tr_attention_dev = preprocess(word_id_mapping)

    keep_prob1 = tf.placeholder(tf.float32, name='input_keep_prob1')
    with tf.name_scope('inputs'):
        x = tf.placeholder(tf.int32, [None, FLAGS.max_sentence_len], name='input_x')
        y = tf.placeholder(tf.float32, [None, FLAGS.max_sentence_len, FLAGS.n_class], name='input_y')
        sen_len = tf.placeholder(tf.int32, [None], name='input_sen_len')
        target_words = tf.placeholder(tf.int32, [None, FLAGS.max_target_len], name='input_target')
        tar_len = tf.placeholder(tf.int32, [None], name='input_tar_len')
        position = tf.placeholder(tf.int32, [None, FLAGS.max_sentence_len], name='position')
        attention1 = tf.placeholder(tf.float32, [None, FLAGS.max_sentence_len, FLAGS.plority],
                                    name='attention_parameter_1')

    inputs_s = tf.nn.embedding_lookup(word_embedding, x)
    inputs_s_1 = tf.nn.embedding_lookup(word_embedding, x)

    position_embeddings = tf.get_variable(
        name='position_embedding',
        shape=[FLAGS.max_sentence_len, FLAGS.position_embedding_dim],
        initializer=tf.random_uniform_initializer(-FLAGS.random_base, FLAGS.random_base),
        regularizer=tf.contrib.layers.l2_regularizer(FLAGS.l2_reg)
    )

    input_position = tf.nn.embedding_lookup(position_embeddings, position)

    # target_1 = tf.reduce_mean(tf.nn.embedding_lookup(word_embedding, target_words), 1, keep_dims=True)
    # batch_size = tf.shape(inputs_s)[0]
    # target_2 = tf.zeros([batch_size, FLAGS.max_sentence_len, FLAGS.embedding_dim]) + target_1

    inputs_s = tf.concat([inputs_s, input_position], 2)

    # inputs_s = tf.concat([inputs_s, target_2], 2)

    target = tf.nn.embedding_lookup(word_embedding, target_words)
    prob, prob1 = LOTN(inputs_s, inputs_s_1, position, sen_len, target, tar_len, attention1, keep_prob1, FLAGS.t1)
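    # `loss1` supervises the main tag sequence against the gold labels, while
    # `loss2` pushes the auxiliary tagger towards the transferred attention
    # labels; the latter is scaled by FLAGS.Auxiliary_loss in the total loss.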
    loss1 = loss_func(y, prob)
    loss2 = loss_func(attention1, prob1)

    loss = loss1 + FLAGS.Auxiliary_loss * loss2

    # acc_num, acc_prob = acc_func(y, prob)

    global_step = tf.Variable(0, name='tr_global_step', trainable=False)
    optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate).minimize(loss, global_step=global_step)

    true_y = tf.argmax(y, 2, name='true_y_1')
    pred_y = tf.argmax(prob, 2, name='pred_y_1')

    true_attention = tf.argmax(attention1, 2, name='true_attention_1')
    pred_attention = tf.argmax(prob1, 2, name='pred_attention_1')

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        init = tf.global_variables_initializer()
        sess.run(init)

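        # Transfer step: restore only the pre-trained "rnn/sen12" BiLSTM weights
        # from the sentence-level checkpoint, excluding the Adam slot variables.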
        if FLAGS.pre_trained == "sentence_tranfer":
            pre_trained_variables = [v for v in tf.global_variables() if
                                     v.name.startswith("rnn/sen12") and "Adam" not in v.name]
            print(pre_trained_variables)
            saver = tf.train.Saver(pre_trained_variables)
            ckpt = tf.train.get_checkpoint_state(FLAGS.pre_trained_path)
            # ckpt = tf.train.get_checkpoint_state('data/yelp/checkpoint8/')
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)

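        # Run one optimization step on a single training mini-batch.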
        def train_step(i, x_f, sen_len_f, target, tl, yi, x_position, x_attention, kp1):
            feed_dict = {
                x: x_f,
                y: yi,
                sen_len: sen_len_f,
                target_words: target,
                tar_len: tl,
                position: x_position,
                attention1: x_attention,
                keep_prob1: kp1
            }
            step, _, losses = sess.run([global_step, optimizer, loss], feed_dict)
            time_str = datetime.datetime.now().isoformat()
            print("{}: Iter {}, step {}, loss {:g}".format(time_str, i, step, losses))

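        # Evaluate one dev mini-batch: returns per-sentence predicted/true tag
        # sequences, trimmed to the real sentence lengths, plus the batch cost.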
        def dev_step(te_x_f, te_sen_len_f, te_target, te_tl, te_yi, te_x_position, te_x_attention):
            feed_dict = {
                x: te_x_f,
                y: te_yi,
                sen_len: te_sen_len_f,
                target_words: te_target,
                tar_len: te_tl,
                position: te_x_position,
                attention1: te_x_attention,
                keep_prob1: 1.0
            }

            tf_true, tf_pred, tf_true_attention, tf_pred_attention, _prob, _loss = sess.run(
                [true_y, pred_y, true_attention, pred_attention, prob, loss], feed_dict)
            cost = 0
            pre_label, att_pre_label, true_label, att_true_label = [], [], [], []
            # Trim the padded predictions with the sentence lengths of this batch
            # (not the full dev split, which would misalign later batches).
            for logit, position1, att_logit, att_position1, length in zip(tf_pred, tf_true, tf_pred_attention,
                                                                          tf_true_attention, te_sen_len_f):
                logit = logit[:length]
                tr_position = position1[:length]

                att_logit = att_logit[:length]
                tr_att_position = att_position1[:length]

                cost += _loss * length
                pre_label.append(logit)
                true_label.append(tr_position)

                att_pre_label.append(att_logit)
                att_true_label.append(tr_att_position)

            return pre_label, att_pre_label, true_label, att_true_label, cost

        checkpoint_dir = os.path.abspath(FLAGS.saver_checkpoint)
        checkpoint_prefix = os.path.join(checkpoint_dir, "model")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)

        max_f1 = 0
        max_recall = 0
        max_precision = 0
        last_improvement = 0
        require_improvement_iterations = 5
        max_label, max_att_label = None, None
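        # Training loop: one pass over the shuffled training batches per iteration.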
        for i in range(FLAGS.n_iter):
            batches_train = batch_iter(
                list(zip(tr_x_train, tr_sen_len_train, tr_target_word_train, tr_tar_len_train, tr_y_train,
                         tr_position_train, tr_attention_train)), FLAGS.batch_size, 1, True)
            for batch in batches_train:
                x_batch, sen_len_batch, target_batch, tar_len_batch, y_batch, position_batch, attention_batch = zip(
                    *batch)
                train_step(i, x_batch, sen_len_batch, target_batch, tar_len_batch, y_batch, position_batch,
                           attention_batch, FLAGS.keep_prob1)

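            # Evaluate on the dev split after each pass over the training data.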
            batches_test = batch_iter(
                list(zip(tr_x_dev, tr_sen_len_dev, tr_target_word_dev, tr_tar_len_dev, tr_y_dev, tr_position_dev,
                         tr_attention_dev)), 500, 1, False)
            label_pp, att_label_pp, label_tt, att_label_tt = [], [], [], []
            cost1 = 0
            for batch_ in batches_test:
                te_x_batch, te_sen_len_batch, te_target_batch, te_tar_len_batch, te_y_batch, te_position_batch, \
                    te_attention_batch = zip(*batch_)
                label_p, att_label_p, label_t, att_label_t, _loss = dev_step(te_x_batch, te_sen_len_batch,
                                                                             te_target_batch, te_tar_len_batch,
                                                                             te_y_batch, te_position_batch,
                                                                             te_attention_batch)
                label_pp += label_p
                label_tt += label_t

                att_label_pp += att_label_p
                att_label_tt += att_label_t

                cost1 += _loss
            print("\nEvaluation:")

            precision, recall, f1 = score_BIO(label_pp, label_tt)
            current_step = tf.train.global_step(sess, global_step)
            print("Iter {}: step {}, loss {}, precision {:g}, recall {:g}, f1 {:g}".format(
                i, current_step, cost1, precision, recall, f1))

            if f1 > max_f1:
                max_f1 = f1
                max_precision = precision
                max_recall = recall
                max_label = label_pp
                max_att_label = att_label_pp
                last_improvement = i
                path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                print("Saved model checkpoint to {}\n".format(path))
            print("topf1 {:g}, precision {:g}, recall {:g}".format(max_f1, max_precision, max_recall))
            print("\n")
            # if i - last_improvement > require_improvement_iterations:
            #     print('No improvement found in a while, stop running')
            #     break
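        # Persist the best dev predictions: one space-separated tag sequence per line.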
        if max_label is not None:
            with open(FLAGS.prob_file, 'w') as fp:
                for ws in max_label:
                    fp.write(' '.join(str(w) for w in ws) + '\n')


if __name__ == '__main__':
    tf.app.run()