forked from sanyam5/skip-thoughts
-
Notifications
You must be signed in to change notification settings - Fork 0
/
vocab.py
81 lines (65 loc) · 2.22 KB
/
vocab.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
"""
This code has been taken and modified from https://github.com/ryankiros/skip-thoughts
Constructing and loading dictionaries
"""
import _pickle as pkl
from collections import OrderedDict
import argparse
def build_dictionary(text):
    """
    Build a frequency-ranked vocabulary from pre-tokenized sentences.

    text: iterable of sentences; each sentence is split on whitespace.

    Returns (worddict, wordcount):
      worddict  -- OrderedDict mapping word -> index, most frequent word
                   first, indices starting at 2 (0: <eos>, 1: <unk>).
                   Words with equal counts keep first-occurrence order.
      wordcount -- dict mapping word -> number of occurrences.
    """
    wordcount = {}
    for sentence in text:
        for token in sentence.split():
            wordcount[token] = wordcount.get(token, 0) + 1

    # sorted() is stable even with reverse=True, so ties stay in the
    # dict's insertion (first-seen) order.
    ranked = sorted(wordcount, key=wordcount.get, reverse=True)

    # Indices 0 and 1 are reserved for <eos> and <unk>.
    worddict = OrderedDict(
        (token, rank) for rank, token in enumerate(ranked, start=2)
    )
    return worddict, wordcount
def load_dictionary(loc='./data/book_dictionary_large.pkl'):
    """
    Load a word dictionary from a pickle file.

    loc: path to the pickle file (e.g. one written by save_dictionary).

    Only the first pickled object in the file is read and returned; any
    objects stored after it (such as word counts) are ignored.

    NOTE(review): unpickling can execute arbitrary code -- only load
    files from trusted sources.
    """
    with open(loc, 'rb') as handle:
        return pkl.load(handle)
def save_dictionary(worddict, wordcount, loc='./data/book_dictionary_large.pkl'):
    """
    Pickle a word dictionary and its word counts to loc.

    Two objects are written back to back into the same file: worddict
    first, then wordcount. A reader that loads only the first object
    (as load_dictionary does) therefore gets worddict.
    """
    with open(loc, 'wb') as handle:
        for payload in (worddict, wordcount):
            pkl.dump(payload, handle)
def build_and_save_dictionary(text, source):
    """
    Return a word dictionary for text, cached at "<source>.pkl".

    If the cache file can be loaded, the dictionary stored in it is
    returned and text is not used. Otherwise the dictionary is built from
    text, saved (together with its word counts) to the cache path, and
    returned.

    text:   list of sentences, passed to build_dictionary
    source: path prefix; ".pkl" is appended to form the cache location

    Returns the word -> index mapping only; word counts are saved to the
    cache file but not returned.
    """
    save_loc = source + ".pkl"
    try:
        cached = load_dictionary(save_loc)
    except Exception:
        # Best-effort cache: a missing, unreadable or corrupt file just
        # means we rebuild. Unpickling can raise many exception types, so
        # catch Exception -- but not a bare except, which would also
        # swallow KeyboardInterrupt / SystemExit.
        pass
    else:
        print("Using cached dictionary at {}".format(save_loc))
        return cached

    # build again and save
    print("unable to load from cached, building fresh")
    worddict, wordcount = build_dictionary(text)
    print("Got {} unique words".format(len(worddict)))
    print("Saving dictionary at {}".format(save_loc))
    save_dictionary(worddict, wordcount, save_loc)
    return worddict
if __name__ == "__main__":
    # CLI: build a dictionary from a text file (each line is treated as one
    # sentence) and pickle it next to the input as "<text_file>.pkl".
    # Unlike build_and_save_dictionary, this always rebuilds and never
    # consults an existing cache file.
    parser = argparse.ArgumentParser(
        description="Build a word dictionary from a text file.")
    parser.add_argument("text_file", type=str,
                        help="input text file, one sentence per line")
    args = parser.parse_args()
    print("Extracting text from {}".format(args.text_file))
    # Read inside a context manager so the input file is closed promptly.
    with open(args.text_file, "rt") as f:
        text = f.readlines()
    print("Extracting dictionary..")
    worddict, wordcount = build_dictionary(text)
    out_file = args.text_file + ".pkl"
    print("Got {} unique words. Saving to file {}".format(len(worddict), out_file))
    save_dictionary(worddict, wordcount, out_file)
    print("Done.")