data_loader.py
"""Dataset and DataLoader utilities for batching tokenized conversations."""

import random
from collections import defaultdict

import numpy as np
from torch.utils.data import Dataset, DataLoader

from utils import PAD_ID, UNK_ID, SOS_ID, EOS_ID


class DialogDataset(Dataset):
    def __init__(self, sentences, conversation_length, sentence_length, vocab, data=None):
        # [total_data_size, max_conversation_length, max_sentence_length]
        # tokenized raw text of sentences
        self.sentences = sentences
        self.vocab = vocab

        # number of sentences in each conversation
        # [total_data_size]
        self.conversation_length = conversation_length

        # token length of each sentence in each conversation
        # [total_data_size, max_conversation_length]
        self.sentence_length = sentence_length
        self.data = data
        self.len = len(sentences)
    def __getitem__(self, index):
        """Return a single conversation and its length metadata."""
        # [max_conversation_length, max_sentence_length]
        sentence = self.sentences[index]
        conversation_length = self.conversation_length[index]
        sentence_length = self.sentence_length[index]

        # word => word_ids
        sentence = self.sent2id(sentence)

        return sentence, conversation_length, sentence_length
    def __len__(self):
        return self.len

    def sent2id(self, sentences):
        """word => word id"""
        # [max_conversation_length, max_sentence_length]
        return [self.vocab.sent2id(sentence) for sentence in sentences]

def get_loader(sentences, conversation_length, sentence_length, vocab,
               batch_size=100, data=None, shuffle=True):
    """Return a DataLoader over a DialogDataset built from the inputs."""

    def collate_fn(data):
        """
        Collate a list of examples into a batch.
        Args:
            data: list of tuples (sentences, conversation_length, sentence_length)
        Return:
            Batch of each feature, kept as Python lists/tuples
            (tensor conversion is left to the model):
            - sentences: [batch_size, max_conversation_length, max_sentence_length]
            - conversation_length: [batch_size]
            - sentence_length: [batch_size, max_conversation_length]
        """
        # Sort by conversation length (descending) so that
        # 'pack_padded_sequence' can be used downstream
        data.sort(key=lambda x: x[1], reverse=True)

        # Separate the features
        sentences, conversation_length, sentence_length = zip(*data)

        return sentences, conversation_length, sentence_length

    dataset = DialogDataset(sentences, conversation_length,
                            sentence_length, vocab, data=data)

    data_loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn)

    return data_loader
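

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). `ToyVocab` and the toy
# conversations below are assumptions, not part of the original project;
# the sketch relies solely on the `vocab.sent2id` interface that
# DialogDataset.sent2id expects.
if __name__ == '__main__':

    class ToyVocab(object):
        """Map each word to an integer id; unknown words fall back to id 0."""

        def __init__(self, words):
            self.word2id = {w: i + 1 for i, w in enumerate(words)}

        def sent2id(self, sentence):
            return [self.word2id.get(w, 0) for w in sentence]

    # Two toy conversations, already tokenized.
    toy_sentences = [
        [['hello', 'there'], ['hi']],
        [['how', 'are', 'you'], ['fine', 'thanks'], ['good']],
    ]
    toy_conversation_length = [2, 3]
    toy_sentence_length = [[2, 1], [3, 2, 1]]

    toy_vocab = ToyVocab(['hello', 'there', 'hi', 'how', 'are', 'you',
                          'fine', 'thanks', 'good'])

    loader = get_loader(toy_sentences, toy_conversation_length,
                        toy_sentence_length, toy_vocab,
                        batch_size=2, shuffle=False)

    for batch_sentences, batch_conv_len, batch_sent_len in loader:
        # collate_fn sorts conversations by length, longest first.
        print(batch_conv_len)      # (3, 2)
        print(batch_sentences[0])  # id sequences of the 3-sentence conversation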