Skip to content

Commit a36b91c

Browse files
authored
Update vqa_gen_dataset.py, set max_tgt_length
1 parent b358cd6 commit a36b91c

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

data/mm_data/vqa_gen_dataset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def __getitem__(self, index):
163163
ref_dict = {item.split('|!+')[1]: float(item.split('|!+')[0]) for item in ref.split('&&')}
164164
answer = max(ref_dict, key=ref_dict.get)
165165
conf = torch.tensor([ref_dict[answer]])
166-
tgt_item = self.encode_text(" {}".format(answer))
166+
tgt_item = self.encode_text(" {}".format(answer), length=self.max_tgt_length)
167167

168168
if self.add_object and predict_objects is not None:
169169
predict_object_seq = ' '.join(predict_objects.strip().split('&&')[:self.max_object_length])

0 commit comments

Comments
 (0)