-
Notifications
You must be signed in to change notification settings - Fork 2
/
array_tool.py
41 lines (33 loc) · 1.12 KB
/
array_tool.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import numpy as np
def paired_shuffle(dats, labels):
    """Shuffle ``dats`` and ``labels`` with one shared random permutation.

    Both sequences are reordered by the same permutation, so every sample
    stays aligned with its label. The permutation is produced by
    ``np.random.shuffle`` and therefore follows NumPy's global random state
    (e.g. ``np.random.seed``).

    Args:
        dats: Indexable sequence of samples (anything supporting ``len`` and
            integer indexing, such as a list or a NumPy array).
        labels: Indexable sequence of labels; must have the same length as
            ``dats``.

    Returns:
        A tuple ``(shuffled_dats, shuffled_labels, indice)``. The first two
        are new lists; ``indice`` is the list of original positions, so
        position ``k`` of each shuffled list holds the item taken from
        original position ``indice[k]``.

    Raises:
        ValueError: If ``dats`` and ``labels`` differ in length.
    """
    n_items = len(dats)
    if len(labels) != n_items:
        # Without this check a shorter ``labels`` fails mid-loop with a bare
        # IndexError and a longer one silently loses its extra entries.
        raise ValueError(
            f"dats and labels must have the same length, "
            f"got {n_items} and {len(labels)}"
        )
    indice = list(range(n_items))
    np.random.shuffle(indice)
    shuffled_dats = [dats[idx] for idx in indice]
    shuffled_labels = [labels[idx] for idx in indice]
    return shuffled_dats, shuffled_labels, indice
def queue_sort(queue):
    """Drain ``queue`` without blocking and return its payloads ordered by index.

    Each queued item must be an ``(index, payload)`` pair. Items are removed
    with ``get_nowait()`` until the queue raises ``queue.Empty``. They are
    then sorted by ``index`` and returned without their indices. Items with
    equal indices keep their dequeue order, because the sort is stable.

    Args:
        queue: Queue-like object whose ``get_nowait()`` raises
            ``queue.Empty`` when no item is available (``queue.Queue`` and
            ``multiprocessing.Queue`` both behave this way).

    Returns:
        List of payloads sorted by their index.

    Raises:
        ValueError or TypeError: If a dequeued item cannot be unpacked into
            an ``(index, payload)`` pair.

    NOTE(review): for a ``multiprocessing.Queue``, the stdlib docs warn that
    an item just put may not yet be visible to ``get_nowait()``. Confirm that
    producers have finished (e.g. been joined) before calling this.
    """
    # Imported locally: the parameter name ``queue`` shadows the stdlib module.
    from queue import Empty

    input_ex = []
    while True:
        try:
            item = queue.get_nowait()
        except Empty:
            # Only "nothing left" ends the drain. The former bare ``except:``
            # also swallowed real errors (and KeyboardInterrupt) and returned
            # a silently truncated result.
            break
        idx, dat = item  # enforce the (index, payload) shape
        input_ex.append((idx, dat))
    input_ex.sort(key=lambda pair: pair[0])
    return [dat for _, dat in input_ex]
def kfold(n_sample, n_split=5, shuffle=False):
    """Yield ``(idx_train, idx_val)`` index lists for K-fold cross-validation.

    The positions ``0 .. n_sample - 1`` are cut into ``n_split`` consecutive
    validation folds of ``n_sample // n_split`` items each. If ``shuffle`` is
    true, they are first shuffled once with ``np.random.shuffle``, which uses
    NumPy's global random state. The last fold also takes the
    ``n_sample % n_split`` leftover positions, so every sample is validated
    exactly once. A fold's training indices are all remaining positions, in
    the order of the (possibly shuffled) index list.

    Args:
        n_sample: Number of samples to split.
        n_split: Number of folds; must satisfy ``1 <= n_split <= n_sample``.
        shuffle: If true, shuffle the positions once before splitting.

    Yields:
        ``(idx_train, idx_val)`` pairs of index lists, one per fold.

    Raises:
        ValueError: If ``n_split`` is outside ``[1, n_sample]``. Because this
            is a generator, the error surfaces when iteration starts.
    """
    if not 1 <= n_split <= n_sample:
        # n_split == 0 used to divide by zero; n_split > n_sample produced
        # folds with empty validation sets.
        raise ValueError(
            f"n_split must be between 1 and n_sample ({n_sample}), got {n_split}"
        )
    indice = list(range(n_sample))
    if shuffle:
        np.random.shuffle(indice)
    split_size = n_sample // n_split
    for ns in range(n_split):
        start = ns * split_size
        # The last fold runs to the end so the remainder is not dropped. The
        # old ``(ns+1)*split_size > n_sample`` test could never be true.
        stop = n_sample if ns == n_split - 1 else start + split_size
        idx_val = indice[start:stop]
        # Validation is a contiguous slice, so training is everything around
        # it. This matches the old filter's order without an O(n^2) list scan.
        idx_train = indice[:start] + indice[stop:]
        yield idx_train, idx_val