
Commit fe6b035

Added script to compute dialogue embeddings
1 parent 9f17938 commit fe6b035

File tree

1 file changed: +171 -0 lines changed

compute_dialogue_embeddings.py

@@ -0,0 +1,171 @@
#!/usr/bin/env python
"""
This script computes dialogue embeddings for dialogues found in a text file.
"""

import argparse
import cPickle
import logging
import math
import os
import time

import numpy

from dialog_encdec import DialogEncoderDecoder
from state import prototype_state

logger = logging.getLogger(__name__)


class Timer(object):
    def __init__(self):
        self.total = 0

    def start(self):
        self.start_time = time.time()

    def finish(self):
        self.total += time.time() - self.start_time

def parse_args():
    parser = argparse.ArgumentParser(
        description="Compute dialogue embeddings from model")

    parser.add_argument("model_prefix",
                        help="Path to the model prefix (without _model.npz or _state.pkl)")

    parser.add_argument("dialogues",
                        help="File of input dialogues (tab separated)")

    parser.add_argument("output",
                        help="Output file")

    parser.add_argument("--verbose",
                        action="store_true", default=False,
                        help="Be verbose")

    parser.add_argument("--use-second-last-state",
                        action="store_true", default=False,
                        help="Output the second-to-last dialogue encoder state instead of the last one")

    return parser.parse_args()

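
# Input format sketch (hypothetical content): each line of the dialogues
# file is one dialogue, with utterances separated by tabs, e.g.
#     how are you ?\ti am fine .\tglad to hear it .
# Each utterance is split on whitespace and mapped to token ids by
# model.words_to_indices in main() below.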

def compute_encodings(joined_contexts, model, model_compute_encoding, output_second_last_state=False):
    context = numpy.zeros((model.seqlen, len(joined_contexts)), dtype='int32')
    context_lengths = numpy.zeros(len(joined_contexts), dtype='int32')
    for idx in range(len(joined_contexts)):
        context_lengths[idx] = len(joined_contexts[idx])
        if context_lengths[idx] < model.seqlen:
            context[:context_lengths[idx], idx] = joined_contexts[idx]
        else:
            # If the context is longer than the maximum context length,
            # truncate it and force the end-of-utterance token at the end
            context[:model.seqlen, idx] = joined_contexts[idx][0:model.seqlen]
            context[model.seqlen-1, idx] = model.eos_sym
            context_lengths[idx] = model.seqlen

    n_samples = len(joined_contexts)

    # Generate the reversed context: the words inside each utterance are
    # reversed in place, while the sos/eos markers keep their positions
    reversed_context = numpy.copy(context)
    for idx in range(context.shape[1]):
        eos_indices = numpy.where(context[:, idx] == model.eos_sym)[0]
        prev_eos_index = -1
        for eos_index in eos_indices:
            reversed_context[(prev_eos_index+2):eos_index, idx] = (reversed_context[(prev_eos_index+2):eos_index, idx])[::-1]
            prev_eos_index = eos_index
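
    # Worked example of the reversal above (illustrative token ids only;
    # assume sos_sym = 2 and eos_sym = 1): the context column
    #     [2, 5, 6, 7, 1, 2, 8, 9, 1]
    # becomes
    #     [2, 7, 6, 5, 1, 2, 9, 8, 1]
    # The (prev_eos_index+2) offset skips the eos token of the previous
    # utterance and the sos token of the current one.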

    # Run the encoder on the batch; hs contains the dialogue encoder
    # hidden states at every timestep for every dialogue in the batch
    encoder_states = model_compute_encoding(context, reversed_context, model.seqlen)
    hs = encoder_states[1]

    if output_second_last_state:
        # Take the state at the last real token of each dialogue rather
        # than the state at the end of the padded sequence
        second_last_hidden_state = numpy.zeros((hs.shape[1], hs.shape[2]), dtype='float64')
        for i in range(hs.shape[1]):
            second_last_hidden_state[i, :] = hs[context_lengths[i] - 1, i, :]
        return second_last_hidden_state
    else:
        return hs[-1, :, :]

def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path, 'rb') as src:
        state.update(cPickle.load(src))

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = DialogEncoderDecoder(state)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    contexts = [[]]
    with open(args.dialogues, "r") as f:
        lines = f.readlines()
    if len(lines):
        contexts = [x.strip().split('\t') for x in lines]

    model_compute_encoding = model.build_encoder_function()
    dialogue_encodings = []

    # Process the dialogues in batches of size model.bs
    joined_contexts = []
    batch_index = 0
    batch_total = int(math.ceil(float(len(contexts)) / float(model.bs)))
    for context_id, context_sentences in enumerate(contexts):

        # Convert the context into a list of token ids
        joined_context = []

        if len(context_sentences) == 0:
            joined_context = [model.eos_sym]
        else:
            for sentence in context_sentences:
                sentence_ids = model.words_to_indices(sentence.split())
                # Add sos and eos tokens around each utterance
                joined_context += [model.sos_sym] + sentence_ids + [model.eos_sym]

        # HACK: append 50 dummy utterances (sos, token 0, eos) to the
        # end of every context
        for i in range(0, 50):
            joined_context += [model.sos_sym] + [0] + [model.eos_sym]

        joined_contexts.append(joined_context)

        if len(joined_contexts) == model.bs:
            batch_index += 1
            logger.debug("[COMPUTE] - Got batch %d / %d" % (batch_index, batch_total))
            encs = compute_encodings(joined_contexts, model, model_compute_encoding, args.use_second_last_state)
            for i in range(len(encs)):
                dialogue_encodings.append(encs[i])

            joined_contexts = []

    # Encode the final partial batch, if any
    if len(joined_contexts) > 0:
        logger.debug("[COMPUTE] - Got batch %d / %d" % (batch_total, batch_total))
        encs = compute_encodings(joined_contexts, model, model_compute_encoding, args.use_second_last_state)
        for i in range(len(encs)):
            dialogue_encodings.append(encs[i])

    # Save encodings to disk
    with open(args.output + '.pkl', 'wb') as f:
        cPickle.dump(dialogue_encodings, f)


if __name__ == "__main__":
    main()
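
Usage sketch (hypothetical file names; the model prefix, dialogue file, and output name below are placeholders): the script is invoked as python compute_dialogue_embeddings.py MyModel Dialogues.txt Embeddings, and the saved encodings can then be reloaded like this, assuming each entry is a numpy vector as returned by compute_encodings:

import cPickle

# The script appends '.pkl' to the output argument it was given
with open('Embeddings.pkl', 'rb') as f:
    dialogue_encodings = cPickle.load(f)

print len(dialogue_encodings)      # one embedding per input dialogue
print dialogue_encodings[0].shape  # dimensionality of the dialogue encoder state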
