batchify.py
import torch


def get_batch(x, vocab, device):
    """Convert a list of sentences (lists of words) into a pair of padded,
    time-major index tensors: one prefixed with <go>, one suffixed with <eos>."""
    go_x, x_eos = [], []
    max_len = max(len(s) for s in x)
    for s in x:
        # Map each word to its index, falling back to <unk> for OOV words.
        s_idx = [vocab.word2idx[w] if w in vocab.word2idx else vocab.unk for w in s]
        padding = [vocab.pad] * (max_len - len(s))
        go_x.append([vocab.go] + s_idx + padding)    # decoder input
        x_eos.append(s_idx + [vocab.eos] + padding)  # decoder target
    # Transpose to time-major layout: (max_len + 1) x batch.
    return torch.LongTensor(go_x).t().contiguous().to(device), \
           torch.LongTensor(x_eos).t().contiguous().to(device)


def get_batches(data, vocab, batch_size, device):
    """Sort sentences by length, then group same-length sentences into
    batches of at most batch_size (so no batch mixes lengths). Returns the
    batches and the sort order, so results can be mapped back to the
    original ordering."""
    order = range(len(data))
    z = sorted(zip(order, data), key=lambda i: len(i[1]))
    order, data = zip(*z)

    batches = []
    i = 0
    while i < len(data):
        # Advance j while sentences keep the same length and the batch fits.
        j = i
        while j < min(len(data), i + batch_size) and len(data[j]) == len(data[i]):
            j += 1
        batches.append(get_batch(data[i:j], vocab, device))
        i = j
    return batches, order
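

# --- Usage sketch (illustrative; not part of the original module) ---
# The Vocab class below is a hypothetical stand-in that assumes only the
# attributes this file relies on: a word2idx dict plus pad/go/eos/unk ids.
if __name__ == '__main__':
    class Vocab:
        def __init__(self, words):
            specials = ['<pad>', '<go>', '<eos>', '<unk>']
            self.word2idx = {w: i for i, w in enumerate(specials + words)}
            self.pad, self.go, self.eos, self.unk = 0, 1, 2, 3

    vocab = Vocab(['the', 'cat', 'sat'])
    data = [['the', 'cat'], ['sat'], ['the', 'cat', 'sat']]
    batches, order = get_batches(data, vocab, batch_size=2, device='cpu')
    for go_x, x_eos in batches:
        print(go_x.shape, x_eos.shape)  # each is (sentence_len + 1) x batch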