forked from ntrang086/image_captioning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtraining.py
301 lines (243 loc) · 10.9 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
import sys
import os
import time
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.utils.data as data
from torch.autograd import Variable
from torchvision import transforms
from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap
from data_loader import MyDataLoader
from model import EncoderCNN, DecoderRNN
class Trainer:
    """The Trainer encapsulates the model training process.

    It holds the encoder/decoder pair, the optimizer and the loss criterion,
    runs one training or evaluation epoch per call, and checkpoints the
    training state into the directory that contains ``current_state_file``
    (``./models``).
    """

    def __init__(self, train_loader, val_loader, encoder, decoder, optimizer,
                 criterion=None, start_epoch=0, rounds=1):
        """Initialize the Trainer state.

        If ``start_epoch > 0`` the state stored in ``current_state_file`` is
        loaded immediately; this also overwrites ``self.epoch`` with the
        epoch recorded in the checkpoint.

        Parameters
        ----------
        train_loader :
            Training data loader; must expose the vocabulary as ``.vocab``.
        val_loader :
            Validation data loader; must expose ``.coco_dataset.coco`` holding
            the reference captions used for CIDEr scoring.
        encoder, decoder : torch.nn.Module
            The image encoder and the caption decoder.
        optimizer : torch.optim.Optimizer
            Optimizer over the trainable parameters.
        criterion : callable, optional
            Loss function. Defaults to cross-entropy that ignores the
            ``<pad>`` token.
        start_epoch : int
            Initial epoch counter; a value > 0 triggers ``load()``.
        rounds : int
            Stored as ``self.rounds``; not used inside this class.
        """
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.encoder = encoder
        self.decoder = decoder
        self.epoch = start_epoch
        self.rounds = rounds
        self.current_state_file = os.path.join('./models', 'current-model.pkl')
        self.optimizer = optimizer
        self.vocab = self.train_loader.vocab
        self.vocab_size = len(self.vocab)
        if criterion is None:
            # Padding positions must not contribute to the loss.
            pad_idx = self.vocab.word2idx['<pad>']
            self.criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
        else:
            self.criterion = criterion
        if torch.cuda.is_available():
            self.map_location = torch.device('cuda')
            self.criterion.cuda()
            self.encoder.cuda()
            self.decoder.cuda()
        else:
            self.map_location = torch.device('cpu')
        # self.cider[k] is the CIDEr score of epoch k + 1.
        self.cider = []
        if self.epoch > 0:
            self.load()

    def load(self):
        """Restore the training state saved in ``current_state_file``.

        Restores the encoder, decoder and optimizer state as well as the
        epoch counter and the list of past CIDEr scores.
        """
        checkpoint = torch.load(self.current_state_file, map_location=self.map_location)
        # Load the pre-trained weights
        self.encoder.load_state_dict(checkpoint['encoder'])
        self.decoder.load_state_dict(checkpoint['decoder'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.epoch = checkpoint['epoch']
        self.cider = checkpoint['cider']
        print('Successfully loaded epoch {}'.format(self.epoch))

    def save_as(self, file_name):
        """Save the training state to ``file_name`` with ``torch.save``.

        The following values are saved:
          - encoder parameters,
          - decoder parameters,
          - optimizer state,
          - current epoch,
          - list of CIDEr scores from the evaluation of past epochs.

        Missing parent directories are created first.

        Parameters
        ----------
        file_name : str
            Path of the file to write.
        """
        directory = os.path.dirname(file_name)
        if directory:
            # torch.save does not create directories; a fresh checkout has
            # no ./models folder yet.
            os.makedirs(directory, exist_ok=True)
        torch.save({"encoder": self.encoder.state_dict(),
                    "decoder": self.decoder.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "cider": self.cider,
                    "epoch": self.epoch
                    }, file_name)

    def save(self):
        """Save the training state twice.

        The state goes to ``current_state_file`` (the resume point) and to a
        per-epoch ``epoch-model-<epoch>.pkl`` in the same directory. See
        ``save_as`` for the saved values.
        """
        model_dir = os.path.dirname(self.current_state_file)
        self.save_as(self.current_state_file)
        self.save_as(os.path.join(model_dir, "epoch-model-{}.pkl".format(self.epoch)))

    def clean_sentence(self, word_idx_list):
        """Convert a sequence of word ids into a caption string.

        Words are joined with single spaces. Decoding stops at the first end
        token, and start tokens are skipped.

        Parameters
        ----------
        word_idx_list : sequence
            Word indices as produced by the decoder.

        Returns
        -------
        str
            The decoded sentence (empty if no ordinary words are found).
        """
        words = []
        for vocab_id in word_idx_list:
            word = self.vocab.idx2word[vocab_id]
            if word == self.vocab.end_word:
                break
            if word != self.vocab.start_word:
                words.append(word)
        return " ".join(words)

    def train(self):
        """Train the model for one epoch over ``train_loader``.

        Increments ``self.epoch`` and saves a checkpoint afterwards.

        Returns
        -------
        float
            The epoch's average training loss, or NaN if the loader yielded
            no batches.
        """
        # Switch to train mode
        self.encoder.train()
        self.decoder.train()
        total_loss = 0.0
        num_steps = 0
        pbar = tqdm(self.train_loader)
        pbar.set_description('training epoch {}'.format(self.epoch))
        for batch in pbar:
            num_steps += 1
            images, captions, lengths = batch[0], batch[1], batch[2]
            # Move to GPU if CUDA is available.
            # NOTE(review): if DecoderRNN packs sequences with
            # pack_padded_sequence, recent PyTorch requires `lengths` on the
            # CPU; confirm against the decoder before relying on this.
            if torch.cuda.is_available():
                images = images.cuda()
                captions = captions.cuda()
                lengths = lengths.cuda()
            # Pass the inputs through the CNN-RNN model
            features = self.encoder(images)
            outputs = self.decoder(features, captions, lengths)
            # Flatten the batch/time dimensions so that every token becomes
            # one classification sample over the vocabulary. reshape (unlike
            # view) also works on non-contiguous outputs.
            outputs = outputs.reshape(-1, self.vocab_size)
            captions = captions.reshape(-1)
            loss = self.criterion(outputs, captions)
            # backward() accumulates gradients, so clear the previous
            # mini-batch's gradients first.
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            loss_value = loss.item()
            total_loss += loss_value
            pbar.set_postfix(last=loss_value, avg=total_loss / num_steps)
        self.epoch += 1
        self.save()
        return total_loss / num_steps if num_steps else float('nan')

    def evaluate(self):
        """Score the current model on the validation set with CIDEr.

        A caption is generated for every validation image with
        ``decoder.sample_beam_search`` and scored with ``COCOEvalCap``
        against the references of ``val_loader``. The score is stored in
        ``self.cider[self.epoch - 1]`` and a checkpoint is saved. If the
        score beats every previously recorded one, the state is also saved
        as ``best-model.pkl``.

        Returns
        -------
        float
            The CIDEr score of the current epoch.

        Raises
        ------
        RuntimeError
            If no epoch has been trained yet (``self.epoch < 1``), because
            there is no slot in ``self.cider`` for the score.
        """
        if self.epoch < 1:
            raise RuntimeError(
                'evaluate() needs at least one trained epoch (epoch={})'.format(self.epoch))
        # Switch to evaluation mode
        self.encoder.eval()
        self.decoder.eval()
        cocoRes = COCO()
        anns = []
        # Disable gradient calculation because we are in inference mode
        with torch.no_grad():
            pbar = tqdm(self.val_loader)
            pbar.set_description('evaluating epoch {}'.format(self.epoch))
            for batch in pbar:
                # batch[0] holds the images, batch[3] a tensor of image ids.
                images, img_ids = batch[0], batch[3]
                if torch.cuda.is_available():
                    images = images.cuda()
                features = self.encoder(images).unsqueeze(1)
                # Beam search is run for one image at a time.
                for i in range(img_ids.size(0)):
                    outputs = self.decoder.sample_beam_search(features[i].unsqueeze(0))
                    anns.append({'image_id': img_ids[i].item(),
                                 'caption': self.clean_sentence(outputs[0])})
        # pycocotools indexes result annotations by a unique 'id'.
        for ann_id, ann in enumerate(anns):
            ann['id'] = ann_id
        cocoRes.dataset['annotations'] = anns
        cocoRes.createIndex()
        cocoEval = COCOEvalCap(self.val_loader.coco_dataset.coco, cocoRes)
        # Score only the images for which a caption was generated.
        cocoEval.params['image_id'] = {ann['image_id'] for ann in anns}
        cocoEval.evaluate()
        # pycocoevalcap may return a numpy scalar. Store a plain float so the
        # checkpoint stays loadable with torch.load(weights_only=True).
        cider = float(cocoEval.eval['CIDEr'])
        previous_best = max(self.cider) if self.cider else 0
        # Append the score of a newly evaluated epoch, or overwrite the score
        # of a re-evaluated one. This assumes the scores of epochs
        # 1..epoch-1 are already recorded.
        if len(self.cider) < self.epoch:
            self.cider.append(cider)
        else:
            self.cider[self.epoch - 1] = cider
        self.save()
        print("DEBUG: self.epoch: {}, self.cider: {}".format(self.epoch, self.cider))
        if cider > previous_best:
            print('CIDEr improved: {:.2f} => {:.2f}'.format(previous_best, cider))
            model_dir = os.path.dirname(self.current_state_file)
            self.save_as(os.path.join(model_dir, "best-model.pkl"))
        return cider
# ---------------------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------------------
batch_size = 128        # number of images per mini-batch
vocab_threshold = 5     # minimum word count for the vocabulary (applied by MyDataLoader)
vocab_from_file = True  # reuse the vocabulary file on disk instead of rebuilding it
embed_size = 512        # size of the image feature / word embedding vectors
hidden_size = 512       # size of the decoder RNN hidden state
num_epochs = 10         # train+evaluate rounds run by the main loop

# ---------------------------------------------------------------------------
# Training data: random crop + horizontal flip for augmentation, then the
# per-channel normalisation used for the pre-trained encoder.
# ---------------------------------------------------------------------------
transform_train = transforms.Compose([
    transforms.Resize(256),             # shorter image side -> 256 px
    transforms.RandomCrop(224),         # 224x224 patch at a random position
    transforms.RandomHorizontalFlip(),  # mirror with probability 0.5
    transforms.ToTensor(),              # PIL image -> tensor
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])
train_loader = MyDataLoader(
    transform=transform_train,
    mode='train',
    batch_size=batch_size,
    vocab_threshold=vocab_threshold,
    vocab_from_file=vocab_from_file,
)

# ---------------------------------------------------------------------------
# Validation data: deterministic centre crop, same normalisation. The
# vocabulary is always read from file so it matches the training one.
# ---------------------------------------------------------------------------
transform_val = transforms.Compose([
    transforms.Resize(256),             # shorter image side -> 256 px
    transforms.CenterCrop(224),         # 224x224 patch from the centre
    transforms.ToTensor(),              # PIL image -> tensor
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])
val_loader = MyDataLoader(
    transform=transform_val,
    mode='val',
    batch_size=batch_size,
    vocab_threshold=vocab_threshold,
    vocab_from_file=True,
)
# Output dimension of the decoder = number of words in the training vocabulary.
vocab_size = len(train_loader.vocab)

# Image encoder and caption decoder share the embedding size.
encoder = EncoderCNN(embed_size)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)

# Trainable parameters: the whole decoder plus the encoder's `embed`, `bn`
# and `resnet` sub-modules (in that order).
params = [
    *decoder.parameters(),
    *encoder.embed.parameters(),
    *encoder.bn.parameters(),
    *encoder.resnet.parameters(),
]

# AdamW with its AMSGrad variant enabled.
optimizer = torch.optim.AdamW(
    params=params,
    lr=0.001,
    weight_decay=0.05,
    amsgrad=True,
)

trainer = Trainer(train_loader, val_loader, encoder, decoder, optimizer)
# Bootstrap: with no checkpoint on disk there is nothing to resume, so run
# one training epoch first (train() writes the checkpoint), then resume from
# the latest saved state.
if not os.path.exists(trainer.current_state_file):
    trainer.train()
trainer.load()

# A resumed run may stop after training an epoch but before scoring it;
# evaluate that epoch first so the CIDEr history has no gap.
if trainer.epoch > len(trainer.cider):
    print('Epoch {} not yet evaluated'.format(trainer.epoch))
    trainer.evaluate()

# Alternate one training epoch and one evaluation, num_epochs times.
for _ in range(num_epochs):
    trainer.train()
    trainer.evaluate()