towards continuous optimisation
Pasquale Minervini committed Dec 11, 2017
1 parent ed7a2b1 commit fd9d942
Showing 2 changed files with 76 additions and 18 deletions.
90 changes: 74 additions & 16 deletions bin/nli-csearch-cli.py
@@ -4,7 +4,6 @@
# Running:
# $ python3 ./bin/nli-dsearch-cli.py --has-bos --has-unk --restore models/snli/dam_1/dam_1

import os
import sys

import json
@@ -27,7 +26,7 @@

import logging

np.set_printoptions(threshold=np.nan)
# np.set_printoptions(threshold=np.nan)

logger = logging.getLogger(__name__)
rs = np.random.RandomState(0)
@@ -51,6 +50,26 @@
lm_loss = lm_cost = None


def log_perplexity(sentences, sizes):
    # Score each sentence under the language model by feeding it one token at a time.
    assert sentences.shape[0] == sizes.shape[0]
    _batch_size = sentences.shape[0]
    # x holds the current token, y the next (target) token; one column per step
    x = np.zeros(shape=(_batch_size, 1))
    y = np.zeros(shape=(_batch_size, 1))
    # Skip the first token of each sentence and adjust the lengths accordingly
    _sentences, _sizes = sentences[:, 1:], sizes[:] - 1
    state = session.run(lm_cell.zero_state(_batch_size, tf.float32))
    loss_values = []
    for j in range(_sizes.max() - 1):
        x[:, 0] = _sentences[:, j]
        y[:, 0] = _sentences[:, j + 1]
        feed = {lm_input_data_ph: x, lm_targets_ph: y, lm_initial_state: state}
        loss_value, state = session.run([lm_loss, lm_final_state], feed_dict=feed)
        loss_values += [loss_value]
    # [num_steps, batch] -> [batch, num_steps]
    loss_values = np.array(loss_values).transpose()
    # Sum each sentence's per-token losses over its own (shortened) length
    __sizes = _sizes - 2
    res = np.array([np.sum(loss_values[_i, :__sizes[_i]]) for _i in range(loss_values.shape[0])])
    return res


def main(argv):
logger.info('Command line: {}'.format(' '.join(arg for arg in argv)))

@@ -66,7 +85,8 @@ def fmt(prog):
argparser.add_argument('--embedding-size', action='store', type=int, default=300)
argparser.add_argument('--representation-size', action='store', type=int, default=200)

argparser.add_argument('--batch-size', action='store', type=int, default=32)
argparser.add_argument('--batch-size', '-b', action='store', type=int, default=32)
argparser.add_argument('--seq-length', action='store', type=int, default=5)

argparser.add_argument('--seed', action='store', type=int, default=0)

@@ -127,12 +147,7 @@ def fmt(prog):
rnn_size = config['rnn_size']
num_layers = config['num_layers']

label_to_index = {
'entailment': entailment_idx,
'neutral': neutral_idx,
'contradiction': contradiction_idx,
}

label_to_index = {'entailment': entailment_idx, 'neutral': neutral_idx, 'contradiction': contradiction_idx}
max_len = None

args = dict(
@@ -199,39 +214,82 @@ def fmt(prog):
global lm_input_data_ph, lm_targets_ph, lm_initial_state
lm_input_data_ph = tf.placeholder(tf.int32, [None, seq_length], name='input_data')
lm_targets_ph = tf.placeholder(tf.int32, [None, seq_length], name='targets')
lm_initial_state = lm_cell.zero_state(lm_batch_size, tf.float32, )

lm_initial_state = lm_cell.zero_state(lm_batch_size, tf.float32)

with tf.variable_scope('rnnlm'):
lm_W = tf.get_variable(name='W', shape=[rnn_size, vocab_size],
lm_W = tf.get_variable(name='W',
shape=[rnn_size, vocab_size],
initializer=tf.contrib.layers.xavier_initializer())

lm_b = tf.get_variable(name='b', shape=[vocab_size],
lm_b = tf.get_variable(name='b',
shape=[vocab_size],
initializer=tf.zeros_initializer())

lm_emb_lookup = tf.nn.embedding_lookup(embedding_layer, lm_input_data_ph)
lm_emb_projection = tf.contrib.layers.fully_connected(inputs=lm_emb_lookup, num_outputs=rnn_size,

lm_emb_projection = tf.contrib.layers.fully_connected(inputs=lm_emb_lookup,
num_outputs=rnn_size,
weights_initializer=tf.contrib.layers.xavier_initializer(),
biases_initializer=tf.zeros_initializer())

lm_inputs = tf.split(lm_emb_projection, seq_length, 1)
lm_inputs = [tf.squeeze(input_, [1]) for input_ in lm_inputs]

lm_outputs, lm_last_state = legacy_seq2seq.rnn_decoder(decoder_inputs=lm_inputs, initial_state=lm_initial_state,
cell=lm_cell, loop_function=None, scope='rnnlm')
lm_outputs, lm_last_state = legacy_seq2seq.rnn_decoder(decoder_inputs=lm_inputs,
initial_state=lm_initial_state,
cell=lm_cell,
loop_function=None,
scope='rnnlm')

lm_output = tf.reshape(tf.concat(lm_outputs, 1), [-1, rnn_size])

lm_logits = tf.matmul(lm_output, lm_W) + lm_b
lm_probabilities = tf.nn.softmax(lm_logits)
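Note: the shape bookkeeping in the rnnlm block is roughly the following; a NumPy sketch with hypothetical toy sizes, not part of this commit:

import numpy as np

batch, seq_length, rnn_size, vocab_size = 2, 5, 8, 11  # toy sizes, for illustration only
outputs = [np.zeros((batch, rnn_size)) for _ in range(seq_length)]  # one decoder output per time step
output = np.concatenate(outputs, axis=1).reshape(-1, rnn_size)      # [batch * seq_length, rnn_size]
logits = output @ np.zeros((rnn_size, vocab_size))                  # [batch * seq_length, vocab_size]
assert logits.shape == (batch * seq_length, vocab_size)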

global lm_loss, lm_cost, lm_final_state
lm_loss = legacy_seq2seq.sequence_loss_by_example(logits=[lm_logits], targets=[tf.reshape(lm_targets_ph, [-1])],
lm_loss = legacy_seq2seq.sequence_loss_by_example(logits=[lm_logits],
targets=[tf.reshape(lm_targets_ph, [-1])],
weights=[tf.ones([lm_batch_size * seq_length])])
lm_cost = tf.reduce_sum(lm_loss) / lm_batch_size / seq_length
lm_final_state = lm_last_state
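Note: since lm_cost averages the sequence loss over lm_batch_size * seq_length tokens, a per-token perplexity can be read off an evaluated cost; a small sketch, where cost_value is a hypothetical number standing in for session.run(lm_cost, feed_dict=...):

import numpy as np

cost_value = 4.2                 # hypothetical average per-token cross-entropy, in nats
perplexity = np.exp(cost_value)  # ~66.7, the model's effective branching factor per token
print(perplexity)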

discriminator_vars = tfutil.get_variables_in_scope(discriminator_scope_name)
lm_vars = tfutil.get_variables_in_scope(lm_scope_name)

predictions_int = tf.cast(predictions, tf.int32)

saver = tf.train.Saver(discriminator_vars, max_to_keep=1)
lm_saver = tf.train.Saver(lm_vars, max_to_keep=1)

session_config = tf.ConfigProto()
session_config.gpu_options.allow_growth = True

global session
with tf.Session(config=session_config) as session:
logger.info('Total Parameters: {}'.format(tfutil.count_trainable_parameters()))

saver.restore(session, restore_path)

lm_ckpt = tf.train.get_checkpoint_state(lm_path)
lm_saver.restore(session, lm_ckpt.model_checkpoint_path)

embedding_layer_value = session.run(embedding_layer)
print(embedding_layer_value.shape)

text = ['The', 'girl', 'runs', 'on', 'the', 'plane', '.']

sentences = np.array([[token_to_index[token] for token in text]])
sizes = np.array([len(text)])

print(log_perplexity(sentences, sizes))

feed = {
sentence1_ph: sentences,
sentence1_len_ph: sizes
}
tmp = session.run(sentence1_embedding, feed_dict=feed)
print(tmp.shape)

if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
4 changes: 2 additions & 2 deletions bin/nli-dsearch-cli.py
@@ -56,7 +56,7 @@ def relu(x):
return np.maximum(x, 0)


def log_perplexity(sentences, sizes):
assert sentences.shape[0] == sizes.shape[0]
_batch_size = sentences.shape[0]
x = np.zeros(shape=(_batch_size, 1))
@@ -445,7 +445,7 @@ def fmt(prog):
global lm_input_data_ph, lm_targets_ph, lm_initial_state
lm_input_data_ph = tf.placeholder(tf.int32, [None, seq_length], name='input_data')
lm_targets_ph = tf.placeholder(tf.int32, [None, seq_length], name='targets')
lm_initial_state = lm_cell.zero_state(lm_batch_size, tf.float32, )
lm_initial_state = lm_cell.zero_state(lm_batch_size, tf.float32)

with tf.variable_scope('rnnlm'):
lm_W = tf.get_variable(name='W', shape=[rnn_size, vocab_size],
