Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Pasquale Minervini committed Nov 30, 2017
1 parent 72747ab commit e8af02e
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
17 changes: 17 additions & 0 deletions inferbeddings/lm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,23 @@ def loop(prev, _):

self.random_state = np.random.RandomState(seed)

def score_sequence(self, session, sequence):
    """Return the summed log-probability of each token given its predecessors.

    The tokens are fed to the network one at a time, starting from the
    cell's zero state and threading the recurrent state from step to step.
    After feeding token ``i``, the log of the probability read at the index
    of token ``i + 1`` is added to the score.  The first token is only
    conditioned on; it is never scored itself.

    NOTE(review): presumably ``self.probabilities`` evaluates to the model's
    distribution over the next token for the fed input and state -- confirm
    against the graph definition.

    Args:
        session: Session used to evaluate the graph (anything exposing
            ``run(fetches, feed_dict)``).
        sequence: Indexable, sliceable sequence of token indices.

    Returns:
        The sum of ``np.log(probabilities[0, next_idx])`` over all
        consecutive token pairs; ``0.0`` when the sequence has fewer than
        two tokens.
    """
    x = np.zeros((1, 1))
    state = session.run(self.cell.zero_state(1, tf.float32))
    log_prob = 0.0

    # Pair every token with its successor.  The final token has no
    # successor to score, so no forward pass is spent on it.
    for idx, next_idx in zip(sequence[:-1], sequence[1:]):
        x[0, 0] = idx
        feed = {
            self.input_data: x,
            self.initial_state: state,
        }
        probabilities, state = session.run(
            [self.probabilities, self.final_state], feed)
        log_prob += np.log(probabilities[0, next_idx])

    return log_prob

def sample(self, session, words, vocab, num=200, prime='first all', sampling_type=1, pick=0, width=4):
def weighted_pick(weights):
t = np.cumsum(weights)
Expand Down
6 changes: 5 additions & 1 deletion tests/inferbeddings/lm/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_lm_sample():
ckpt = tf.train.get_checkpoint_state(lm_path)
saver.restore(session, ckpt.model_checkpoint_path)

for _ in range(32):
for _ in range(4):
sample_value = imodel.sample(session=session,
words=index_to_token,
vocab=token_to_index,
Expand All @@ -76,6 +76,10 @@ def test_lm_sample():
width=4)
print(sample_value)

sequence = [token_to_index[w] for w in ['A', 'girl', 'runs']]
imodel.score_sequence(session=session, sequence=sequence)


# Script entry point: when this file is executed directly, run the test
# function in-process instead of going through the pytest runner.
if __name__ == '__main__':
# Alternative pytest-driven invocation, currently disabled:
#pytest.main([__file__])
test_lm_sample()

0 comments on commit e8af02e

Please sign in to comment.