data_copytask.py
# Adapted from https://github.com/huggingface/transformers/blob/master/examples/language-modeling/run_mlm.py
import random
from argparse import ArgumentParser

import numpy as np
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset


class CopyDataset(Dataset):
    """Synthetic copy-task dataset: random binary sequences plus a delimiter channel."""

    def __init__(self, num_batches, batch_size, seq_min_len, seq_max_len, seq_width):
        self.num_batches = num_batches
        self.batch_size = batch_size
        self.seq_min_len = seq_min_len
        self.seq_max_len = seq_max_len
        self.seq_width = seq_width
        self.seq_len = random.randint(self.seq_min_len, self.seq_max_len)
        self.counter = 0

    def __len__(self):
        # total number of samples exposed to the DataLoader
        return self.num_batches * self.batch_size

    def __getitem__(self, idx):
        # draw a new sequence length at the start of each batch, so every sample
        # within a batch has the same shape and can be collated
        if self.counter == 0:
            self.seq_len = random.randint(self.seq_min_len, self.seq_max_len)
        seq = np.random.binomial(1, 0.5, (self.seq_len, self.seq_width))
        seq = torch.from_numpy(seq)
        outp = seq.clone()
        # The input gets an additional time step and channel for the delimiter
        inp = torch.zeros(self.seq_len + 1, self.seq_width + 1)
        inp[:self.seq_len, :self.seq_width] = seq
        inp[self.seq_len, self.seq_width] = 1.0  # delimiter in the control channel
        # advance the counter and reset it after a full batch
        self.counter = (self.counter + 1) % self.batch_size
        return inp.float(), outp.float()


class CopyTaskDataModule(pl.LightningDataModule):
    def __init__(self,
                 seq_width: int = 8,
                 seq_min_len: int = 1,
                 seq_max_len: int = 20,
                 train_batch_size: int = 1,
                 eval_batch_size: int = 1,
                 train_batches_per_epoch: int = 200,
                 val_batches: int = 50,
                 dataloader_num_workers: int = 4,
                 **kwargs
                 ):
        super().__init__()
        # save hyperparameters
        self.seq_width = seq_width
        self.seq_min_len = seq_min_len
        self.seq_max_len = seq_max_len
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.train_batches_per_epoch = train_batches_per_epoch
        self.val_batches = val_batches
        self.dataloader_num_workers = dataloader_num_workers

    def setup(self, stage='fit'):
        self.train_dataset = CopyDataset(
            self.train_batches_per_epoch,
            self.train_batch_size,
            self.seq_min_len,
            self.seq_max_len,
            self.seq_width
        )
        self.eval_dataset = CopyDataset(
            self.val_batches,
            self.eval_batch_size,
            self.seq_min_len,
            self.seq_max_len,
            self.seq_width
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            num_workers=self.dataloader_num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.eval_dataset,
            batch_size=self.eval_batch_size,
            num_workers=self.dataloader_num_workers,
        )

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--seq_width', type=int, default=8)
        parser.add_argument('--seq_min_len', type=int, default=1)
        parser.add_argument('--seq_max_len', type=int, default=10)
        parser.add_argument('--train_batch_size', type=int, default=16)
        parser.add_argument('--eval_batch_size', type=int, default=16)
        parser.add_argument('--train_batches_per_epoch', type=int, default=500)
        parser.add_argument('--val_batches', type=int, default=100)
        parser.add_argument('--dataloader_num_workers', type=int, default=4)
        return parser
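

if __name__ == '__main__':
    # Minimal usage sketch (an illustration, assuming the standard
    # pytorch_lightning DataModule workflow; the local variable names below are
    # not part of the original script): build the DataModule from the CLI
    # arguments defined above and inspect one training batch.
    parser = CopyTaskDataModule.add_model_specific_args(ArgumentParser())
    args = parser.parse_args()

    dm = CopyTaskDataModule(**vars(args))
    dm.setup('fit')
    inp, outp = next(iter(dm.train_dataloader()))
    # inp:  (batch, seq_len + 1, seq_width + 1) -- bits plus the delimiter channel
    # outp: (batch, seq_len, seq_width)         -- target copy of the input bits
    print(inp.shape, outp.shape)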