-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsplit_data.py
39 lines (32 loc) · 975 Bytes
/
split_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import sys
import os
import errno
import random
from text.utils import deps_from_tsv, deps_to_tsv
# Seed the shared module-level RNG once, at import time, so that the
# random.shuffle call in prepare() yields the same split on every run.
random.seed(a=42)

# Fractions of the shuffled data assigned to the training and validation
# splits; whatever remains (nominally 89%) becomes the test split.
prop_train, prop_valid = 0.1, 0.01
def prepare(fname, expr_dir):
    """Shuffle the data in *fname* and split it into train/valid/test files.

    The data is loaded with ``deps_from_tsv`` and shuffled in place with the
    module-level ``random`` generator (so it is presumably a list). It is then
    partitioned using the module-level fractions ``prop_train`` and
    ``prop_valid``; everything left over goes to the test split. Each split
    is written with ``deps_to_tsv`` to ``train.tsv``, ``valid.tsv`` and
    ``test.tsv`` inside *expr_dir*.

    Args:
        fname: path of the input file, passed straight to ``deps_from_tsv``.
        expr_dir: output directory. It is created if missing; parent
            directories are not created.

    Raises:
        OSError: if *expr_dir* cannot be created, or if it already exists
            but is not a directory.
    """
    print('| read in the data')
    data = deps_from_tsv(fname)
    print('| shuffling')
    random.shuffle(data)

    # int() truncates, so the rounding remainder of both fractions ends up
    # in the test split.
    n_train = int(len(data) * prop_train)
    n_valid = int(len(data) * prop_valid)
    valid_end = n_train + n_valid
    splits = [
        ('train.tsv', data[:n_train]),
        ('valid.tsv', data[n_train:valid_end]),
        ('test.tsv', data[valid_end:]),
    ]

    try:
        os.mkdir(expr_dir)
    except OSError as exc:
        # An existing directory is fine. An existing non-directory at that
        # path is not: the split files could not be written into it, so fail
        # here with a clear error rather than later inside deps_to_tsv.
        if exc.errno != errno.EEXIST or not os.path.isdir(expr_dir):
            raise

    print('| splitting')
    for out_name, split in splits:
        deps_to_tsv(split, os.path.join(expr_dir, out_name))
    print('| done!')
if __name__ == '__main__':
    # Command line: <script> INPUT_FILE OUTPUT_DIR
    # Fail with a usage message (exit status 1, text on stderr) instead of
    # an IndexError traceback when arguments are missing.
    if len(sys.argv) < 3:
        sys.exit('usage: %s INPUT_FILE OUTPUT_DIR' % sys.argv[0])
    prepare(sys.argv[1], sys.argv[2])