forked from ntrang086/image_captioning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvocabulary.py
117 lines (103 loc) · 4.33 KB
/
vocabulary.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import nltk
import pickle
import os.path
from pycocotools.coco import COCO
from pexel import PEXEL
from collections import Counter
class Vocabulary(object):
    """Word <-> integer-index vocabulary built from COCO and Pexels captions.

    Tokens occurring at least `vocab_threshold` times across both caption
    sources are assigned consecutive integer ids; the special pad/start/end/unk
    words always occupy ids 0-3 (in that order).
    """

    def __init__(self,
        vocab_threshold,
        vocab_file="./vocab.pkl",
        start_word="<start>",
        end_word="<end>",
        unk_word="<unk>",
        pad_word="<pad>",
        vocab_from_file=False,
        coco_annotations_file="./cocoapi/annotations/captions_train2014.json",
        pexel_annotations_file="/home/cgawron/pexels/pexels.json"):
        """Initialize the vocabulary.

        Parameters:
            vocab_threshold: Minimum word count threshold; tokens seen fewer
                times are mapped to `unk_word`.
            vocab_file: File the vocabulary is loaded from / pickled to.
            start_word: Special word denoting sentence start.
            end_word: Special word denoting sentence end.
            unk_word: Special word denoting unknown words.
            pad_word: Special word used for sequence padding.
            vocab_from_file: If False, create vocab from scratch & override any
                existing vocab_file. If True, load vocab from existing
                vocab_file, if it exists.
            coco_annotations_file: Path to the COCO train annotation file.
            pexel_annotations_file: Path to the Pexels annotation file.
        """
        self.vocab_threshold = vocab_threshold
        self.vocab_file = vocab_file
        self.start_word = start_word
        self.end_word = end_word
        self.unk_word = unk_word
        self.pad_word = pad_word
        self.pexel_annotations_file = pexel_annotations_file
        self.coco_annotations_file = coco_annotations_file
        self.vocab_from_file = vocab_from_file
        self.get_vocab()

    def get_vocab(self):
        """Load the vocabulary from file or build it from scratch."""
        # `and` (not bitwise `&`) so the flag short-circuits the file check.
        if self.vocab_from_file and os.path.exists(self.vocab_file):
            with open(self.vocab_file, "rb") as f:
                vocab = pickle.load(f)
            # Only the two mapping dicts are taken over from the pickled object.
            self.word2idx = vocab.word2idx
            self.idx2word = vocab.idx2word
            print("Vocabulary successfully loaded from %s!" % self.vocab_file)
        else:
            print("Building vocabulary: {}, {}".format(os.path.exists(self.vocab_file), self.vocab_from_file))
            self.build_vocab()
            # Pickle the whole Vocabulary instance; loading it back requires
            # this class definition to be importable.
            with open(self.vocab_file, "wb") as f:
                pickle.dump(self, f)

    def build_vocab(self):
        """Populate the dictionaries for converting tokens to integers
        (and vice-versa)."""
        self.init_vocab()
        # Special tokens first so they get stable, low indices (pad == 0).
        self.add_word(self.pad_word)
        self.add_word(self.start_word)
        self.add_word(self.end_word)
        self.add_word(self.unk_word)
        self.add_captions()

    def init_vocab(self):
        """Initialize the dictionaries for converting tokens to integers
        (and vice-versa)."""
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        """Add a token to the vocabulary (no-op if already present)."""
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def _count_tokens(self, counter, captions):
        """Tokenize each caption (lower-cased) and accumulate counts in `counter`."""
        total = len(captions)
        for i, caption in enumerate(captions):
            tokens = nltk.tokenize.word_tokenize(str(caption).lower())
            counter.update(tokens)
            if i % 100000 == 0:
                print("[%d/%d] Tokenizing captions..." % (i, total))

    def add_captions(self):
        """Loop over training captions and add all tokens to the vocabulary
        that meet or exceed the threshold."""
        counter = Counter()
        coco = COCO(self.coco_annotations_file)
        self._count_tokens(counter, [coco.anns[ann_id]["caption"] for ann_id in coco.anns])
        pexel = PEXEL(self.pexel_annotations_file)
        self._count_tokens(counter, [pexel.anns[ann_id] for ann_id in pexel.anns])
        words = [word for word, cnt in counter.items()
                 if cnt >= self.vocab_threshold]
        for word in words:
            self.add_word(word)

    def __call__(self, word):
        """Return the index of `word`, or the unk-word index if unknown."""
        if word not in self.word2idx:
            return self.word2idx[self.unk_word]
        return self.word2idx[word]

    def __len__(self):
        """Number of distinct tokens in the vocabulary."""
        return len(self.word2idx)