Skip to content

Commit

Permalink
updating schematic-memory
Browse files Browse the repository at this point in the history
  • Loading branch information
Pasquale Minervini committed Dec 12, 2017
1 parent fd9d942 commit 9153339
Show file tree
Hide file tree
Showing 5 changed files with 53,043 additions and 41,125 deletions.
33 changes: 24 additions & 9 deletions bin/nli-csearch-cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from inferbeddings.nli import FeedForwardDAMS
from inferbeddings.nli import ESIMv1

from inferbeddings.nli.regularizers.adversarial import AdversarialSets

import logging

# np.set_printoptions(threshold=np.nan)
Expand Down Expand Up @@ -262,9 +264,24 @@ def fmt(prog):
saver = tf.train.Saver(discriminator_vars, max_to_keep=1)
lm_saver = tf.train.Saver(lm_vars, max_to_keep=1)

adversary = AdversarialSets(model_class=model_class,
model_kwargs=model_kwargs,
embedding_size=embedding_size,
scope_name='adversary',
batch_size=1,
sequence_length=10,
entailment_idx=entailment_idx,
contradiction_idx=contradiction_idx,
neutral_idx=neutral_idx)

a_loss, a_sequence_set = adversary.rule6_loss()

session_config = tf.ConfigProto()
session_config.gpu_options.allow_growth = True

text = ['The', 'girl', 'runs', 'on', 'the', 'plane', '.']
sentence_ids = [token_to_index[token] for token in text]

global session
with tf.Session(config=session_config) as session:
logger.info('Total Parameters: {}'.format(tfutil.count_trainable_parameters()))
Expand All @@ -275,21 +292,19 @@ def fmt(prog):
lm_saver.restore(session, lm_ckpt.model_checkpoint_path)

embedding_layer_value = session.run(embedding_layer)
print(embedding_layer_value.shape)
assert embedding_layer_value.shape == (vocab_size, embedding_size)

text = ['The', 'girl', 'runs', 'on', 'the', 'plane', '.']

sentences = np.array([[token_to_index[token] for token in text]])
sizes = np.array([len(text)])

print(log_perplexity(sentences, sizes))
sentences, sizes = np.array([sentence_ids]), np.array([len(sentence_ids)])
assert log_perplexity(sentences, sizes) >= 0.0

feed = {
sentence1_ph: sentences,
sentence1_len_ph: sizes
}
tmp = session.run(sentence1_embedding, feed_dict=feed)
print(tmp.shape)
sentence_embedding = session.run(sentence1_embedding, feed_dict=feed)
assert sentence_embedding.shape == (1, len(sentence_ids), embedding_size)



if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
Expand Down
6 changes: 3 additions & 3 deletions data/schematic-memory/WN18RR/MD5SUMS
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
41c0f6926bf801787b012bfcf0c84954 test.txt
ba20bf22503b5f836d2e2dda02c9eaf6 train.txt
ebffd8145068bfbaec7a98a860396b80 valid.txt
2b45ba1ba436b9d4ff27f1d3511224c9 test.txt
35e81af3ae233327c52a87f23b30ad3c train.txt
74a2ee9eca9a8d31f1a7d4d95b5e0887 valid.txt
Loading

0 comments on commit 9153339

Please sign in to comment.