 from scipy import signal
 import copy
 
+from tqdm import tqdm
+
 def set_seed(seed: int = 42) -> None:
     np.random.seed(seed)
     random.seed(seed)
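
Note: set_seed seeds numpy and random only, while the script also relies on torch (torch.save, and the .numpy() calls below suggest the samples are tensors). If reproducibility of the torch-side noise augmentation matters, seeding torch as well would help; a sketch of that extension (an editorial suggestion, not part of this PR):

```python
import random
import numpy as np
import torch

def set_seed(seed: int = 42) -> None:
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)  # assumption: the augmentation noise is torch-generated
```
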
@@ -45,33 +47,33 @@ class DatasetCreator:
         sample_len: length in seconds of each generated sample.
         overlap: percentage of overlapping between samples."""
     def __init__(self, preproc_data, label_kind,
-                 physio_f, gaze_f, block_len, sample_len, overlap):
+                 physio_f, gaze_f, block_len, sample_len, overlap, verbose=False):
         self.preproc_data = preproc_data
         self.label_kind = label_kind
         self.physio_f = physio_f
         self.gaze_f = gaze_f
         self.block_len = block_len
         self.sample_len = sample_len
         self.overlap = overlap
+        self.verbose = verbose
 
     def save_to_list(self):
         """Grabs data from hierarchical structure and unpacks all values.
-        Add gathered data to a list as as single sample."""
+        Add gathered data to a list as a single sample."""
         sub_dir = list_files(self.preproc_data, sorted_dir=False)
         data_list = []
         labels = []
-        for dir in sub_dir:  # for each subject/folder
+        for dir in tqdm(sub_dir, desc='Reading data'):  # for each subject/folder
             subj = int(dir[-2:])
-            print("Working on subject", subj)
 
             # Get labels for current subject and label kind
             all_labels = np.genfromtxt(os.path.join(dir, 'labels_felt{}.csv'  # returns ndarray
                                        .format(self.label_kind)), delimiter=',')
 
             # Get each original sample and create dataset samples
-            id_trials = [x.split("/")[-1].partition("_")[0] for x in list_files(dir, sorted_dir=False)]  # get beginning of files
+            id_trials = [x.split("\\")[-1].partition("_")[0] for x in list_files(dir, sorted_dir=False)]  # get beginning of files
             id_trials = sorted(np.unique(id_trials)[:-1], key=lambda x: int(x))  # remove duplicates, "label", and sort
-            for i, id in enumerate(id_trials):
+            for i, id in enumerate(tqdm(id_trials, desc=f'Subject {subj}')):
                 pupil_data = np.genfromtxt(os.path.join(dir, '{}_PUPIL.csv'
                                            .format(id)), delimiter=',')
                 gaze_data = np.genfromtxt(os.path.join(dir, '{}_GAZE_COORD.csv'
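
Note: the separator change in id_trials ("/" to "\\") makes the filename split Windows-specific. A portable alternative, sketched here only as an editorial suggestion (not part of this PR), would let os.path strip the directory part on either platform:

```python
import os

# os.path.basename drops the directory prefix regardless of the platform separator;
# list_files is this repository's own helper, used as in the loop above
id_trials = [os.path.basename(x).partition("_")[0]
             for x in list_files(dir, sorted_dir=False)]
```
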
@@ -110,7 +112,8 @@ def save_to_list(self):
                 gsr = gsr_data[k: k + n_points_sample_physio]
                 eeg = eeg_data[k: k + n_points_sample_physio]
                 ecg = ecg_data[k: k + n_points_sample_physio]
-
+
+
                 if (len(pupil) != n_points_sample_gaze or len(gaze_coord) != n_points_sample_gaze or len(eye_dist) != n_points_sample_gaze or
                     len(gsr) != n_points_sample_physio or len(eeg) != n_points_sample_physio or len(ecg) != n_points_sample_physio):
                     # sanity check on the samples
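
Note: the k-indexed slices above cut fixed-length windows out of each trial. A minimal sketch of the windowing arithmetic, assuming the constructor parameters shown earlier (sample_len in seconds, physio_f in Hz, overlap as a percentage per the class docstring) and the values passed in __main__:

```python
physio_f, block_len, sample_len, overlap = 128, 30, 10, 0  # values used in __main__

# one window = sample_len seconds of physiological signal at physio_f Hz
n_points_sample_physio = sample_len * physio_f             # 1280 points

# hop between window starts; overlap = 0 yields disjoint windows
step = n_points_sample_physio - int(n_points_sample_physio * overlap / 100)

# candidate starting indices k within one block_len-second trial
starts = range(0, block_len * physio_f - n_points_sample_physio + 1, step)
print(list(starts))  # [0, 1280, 2560] -> three 10 s windows per 30 s block
```
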
@@ -121,8 +124,9 @@ def save_to_list(self):
                 clean_gaze_coord = gaze_coord[gaze_coord != -1]
                 clean_eye_dist = eye_dist[eye_dist != -1]
                 if len(clean_pupil)/len(pupil) < 0.6 or len(clean_gaze_coord)/len(gaze_coord) < 0.6 or len(clean_eye_dist)/len(eye_dist) < 0.6:
-                    print("\033[93mGaze segment too noisy for subject: {}, sample: {}, segment:{}!\033[0m"
-                          .format(subj, id, str(j // (n_points_sample_gaze - overlap_step_gaze))))
+                    if self.verbose:
+                        print("\033[93mGaze segment too noisy for subject: {}, sample: {}, segment:{}!\033[0m"
+                              .format(subj, id, str(j // (n_points_sample_gaze - overlap_step_gaze))))
                     continue
 
                 # Create single variable containing all gaze information
@@ -145,7 +149,6 @@ def save_to_list(self):
 
         return data_list, labels
 
-
 def std_for_SNR(signal, noise, snr):
     '''Compute the gain to be applied to the noise to achieve the given SNR in dB'''
     signal_power = np.var(signal.numpy())
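
Note: the hunk shows only the first line of std_for_SNR. For reference, the standard closed form follows from SNR_dB = 10·log10(signal_power / (gain² · noise_power)); a sketch of how the rest of the function typically looks (the actual body falls outside this diff, so everything past signal_power is an assumption):

```python
import numpy as np
import torch

def std_for_SNR(signal: torch.Tensor, noise: torch.Tensor, snr: float) -> float:
    '''Compute the gain to be applied to the noise to achieve the given SNR in dB'''
    signal_power = np.var(signal.numpy())
    noise_power = np.var(noise.numpy())
    # solve 10*log10(signal_power / (g**2 * noise_power)) = snr for g
    return float(np.sqrt(signal_power / (noise_power * 10 ** (snr / 10))))
```
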
@@ -220,32 +223,39 @@ def load_dataset(data, labels, scaling, noise, m, SNR):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--path_to_csv', type=str)
-    parser.add_argument('--save_path', type=str)
-    parser.add_argument('--label_kind', type=str, default='Vlnc', help="Choose valence (Vlnc) or arousal (Arsl) label")
+    parser.add_argument('--preproc_data_path', type=str, default='hci-tagging-database/preproc_data', help='Path to folder where preprocessed data was saved')
+    parser.add_argument('--save_path', type=str, default='hci-tagging-database/torch_datasets', help='Path to save .pt files')
+    parser.add_argument('--label_kind', type=str, default='Arsl', help="Choose valence (Vlnc) or arousal (Arsl) label")
     parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--verbose', type=bool, action=argparse.BooleanOptionalAction, default=False)
     args = parser.parse_args()
 
     assert args.label_kind in ["Arsl", "Vlnc"]
+    print("Creating dataset for label: ", args.label_kind)
 
     set_seed(args.seed)
 
-    d = DatasetCreator(args.path_to_csv, args.label_kind, physio_f=128, gaze_f=60, block_len=30, sample_len=10, overlap=0)  # create object
+    d = DatasetCreator(args.preproc_data_path, args.label_kind, physio_f=128, gaze_f=60, block_len=30, sample_len=10, overlap=0, verbose=args.verbose)  # create object
     data, labels = d.save_to_list()  # call method
+    print(len(data))
 
     test_size = 0.2
     X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=test_size, random_state=args.seed, stratify=labels)
 
+    # Augmentation
     m = 30  # Number of augmented signals for each original sample
     SNR = 5
 
     train_data = load_dataset(X_train, y_train, scaling=True, noise=True, m=m, SNR=SNR)
-    test_data = load_dataset(X_test, y_test, False, False, 1, None)
+    test_data = load_dataset(X_test, y_test, scaling=False, noise=False, m=1, SNR=None)
 
     print("Len train before augmentation: ", len(X_train))
     print("Len train after augmentation: ", len(train_data))
     print("Len test: ", len(test_data))
     print("Tot dataset: ", len(train_data) + len(test_data))
 
+    if not os.path.exists(args.save_path):
+        os.makedirs(args.save_path)
+
     torch.save(train_data, f'{args.save_path}/train_augmented_data_{args.label_kind}.pt')
-    torch.save(test_data, f'{args.save_path}/test_augmented_data_{args.label_kind}.pt')
+    torch.save(test_data, f'{args.save_path}/test_data_{args.label_kind}.pt')
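
Note: once the script has run, the saved datasets can be loaded back with torch.load; for example, with the default paths above and --label_kind Arsl:

```python
import torch

# paths follow the defaults introduced in this PR
train_data = torch.load('hci-tagging-database/torch_datasets/train_augmented_data_Arsl.pt')
test_data = torch.load('hci-tagging-database/torch_datasets/test_data_Arsl.pt')
print(len(train_data), len(test_data))
```
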