
Commit 36b6f2e

Updated code

committed · 1 parent 178b05f · commit 36b6f2e

8 files changed: +107 −82 lines changed


data/create_dataset.py (+26 −16)
@@ -7,6 +7,8 @@
 from scipy import signal
 import copy
 
+from tqdm import tqdm
+
 def set_seed(seed: int = 42) -> None:
     np.random.seed(seed)
     random.seed(seed)
@@ -45,33 +47,33 @@ class DatasetCreator:
         sample_len: length in seconds of each generated sample.
         overlap: percentage of overlapping between samples."""
     def __init__(self, preproc_data, label_kind,
-                 physio_f, gaze_f, block_len, sample_len, overlap):
+                 physio_f, gaze_f, block_len, sample_len, overlap, verbose=False):
         self.preproc_data = preproc_data
         self.label_kind = label_kind
         self.physio_f = physio_f
         self.gaze_f = gaze_f
         self.block_len = block_len
         self.sample_len = sample_len
         self.overlap = overlap
+        self.verbose = verbose
 
     def save_to_list(self):
         """Grabs data from hierarchical structure and unpacks all values.
-        Add gathered data to a list as as single sample."""
+        Add gathered data to a list as a single sample."""
         sub_dir = list_files(self.preproc_data, sorted_dir=False)
         data_list = []
         labels = []
-        for dir in sub_dir: # for each subject/folder
+        for dir in tqdm(sub_dir, desc='Reading data'): # for each subject/folder
             subj = int(dir[-2:])
-            print("Working on subject", subj)
 
             # Get labels for current subject and label kind
             all_labels = np.genfromtxt(os.path.join(dir, 'labels_felt{}.csv' # return Dataframe
                                        .format(self.label_kind)), delimiter=',')
 
             # Get each original sample and create dataset samples
-            id_trials = [x.split("/")[-1].partition("_")[0] for x in list_files(dir, sorted_dir=False)] # get beginning of files
+            id_trials = [x.split("\\")[-1].partition("_")[0] for x in list_files(dir, sorted_dir=False)] # get beginning of files
             id_trials = sorted(np.unique(id_trials)[:-1], key=lambda x: int(x)) # remove duplicates, "label", and sort
-            for i, id in enumerate(id_trials):
+            for i, id in enumerate(tqdm(id_trials, desc=f'Subject {subj}')):
                 pupil_data = np.genfromtxt(os.path.join(dir, '{}_PUPIL.csv'
                                            .format(id)), delimiter=',')
                 gaze_data = np.genfromtxt(os.path.join(dir, '{}_GAZE_COORD.csv'
@@ -110,7 +112,8 @@ def save_to_list(self):
                     gsr = gsr_data[k : k + n_points_sample_physio]
                     eeg = eeg_data[k : k + n_points_sample_physio]
                     ecg = ecg_data[k : k + n_points_sample_physio]
-
+
+
                     if (len(pupil) != n_points_sample_gaze or len(gaze_coord) != n_points_sample_gaze or len(eye_dist) != n_points_sample_gaze or
                         len(gsr) != n_points_sample_physio or len(eeg) != n_points_sample_physio or len(ecg) != n_points_sample_physio):
                         # sanity check on the samples
@@ -121,8 +124,9 @@ def save_to_list(self):
                     clean_gaze_coord = gaze_coord[gaze_coord != -1]
                     clean_eye_dist = eye_dist[eye_dist != -1]
                     if len(clean_pupil)/len(pupil) < 0.6 or len(clean_gaze_coord)/len(gaze_coord) < 0.6 or len(clean_eye_dist)/len(eye_dist) < 0.6:
-                        print("\033[93mGaze segment too noisy for subject: {}, sample: {}, segment:{}!\033[0m"
-                              .format(subj, id, str(j//(n_points_sample_gaze - overlap_step_gaze))))
+                        if self.verbose:
+                            print("\033[93mGaze segment too noisy for subject: {}, sample: {}, segment:{}!\033[0m"
+                                  .format(subj, id, str(j//(n_points_sample_gaze - overlap_step_gaze))))
                         continue
 
                     # Create single variable containing all gaze information
@@ -145,7 +149,6 @@ def save_to_list(self):
 
         return data_list, labels
 
-
 def std_for_SNR(signal, noise, snr):
     '''Compute the gain to be applied to the noise to achieve the given SNR in dB'''
     signal_power = np.var(signal.numpy())
@@ -220,32 +223,39 @@ def load_dataset(data, labels, scaling, noise, m, SNR):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--path_to_csv', type=str)
-    parser.add_argument('--save_path', type=str)
-    parser.add_argument('--label_kind', type=str, default='Vlnc', help="Choose valence (Vlnc) or arousal (Arsl) label")
+    parser.add_argument('--preproc_data_path', type=str, default='hci-tagging-database/preproc_data', help='Path to folder where preprocessed data was saved')
+    parser.add_argument('--save_path', type=str, default='hci-tagging-database/torch_datasets', help='Path to save .pt files')
+    parser.add_argument('--label_kind', type=str, default='Arsl', help="Choose valence (Vlnc) or arousal (Arsl) label")
     parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--verbose', type=bool, action=argparse.BooleanOptionalAction, default=False)
     args = parser.parse_args()
 
     assert args.label_kind in ["Arsl", "Vlnc"]
+    print("Creating dataset for label: ", args.label_kind)
 
     set_seed(args.seed)
 
-    d = DatasetCreator(args.path_to_csv, args.label_kind, physio_f = 128, gaze_f = 60, block_len = 30, sample_len=10, overlap = 0) # create object
+    d = DatasetCreator(args.preproc_data_path, args.label_kind, physio_f = 128, gaze_f = 60, block_len = 30, sample_len=10, overlap = 0, verbose=args.verbose) # create object
     data, labels = d.save_to_list() # call method
+    print(len(data))
 
     test_size = 0.2
     X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=test_size, random_state=args.seed, stratify=labels)
 
+    # Augmentation
     m = 30 # Number of augmented signals for each original sample
    SNR = 5
 
     train_data = load_dataset(X_train, y_train, scaling=True, noise=True, m=m, SNR=SNR)
-    test_data = load_dataset(X_test, y_test, False, False, 1, None)
+    test_data = load_dataset(X_test, y_test, scaling=False, noise=False, m=1, SNR=None)
 
     print("Len train before augmentation: ", len(X_train))
     print("Len train after augmentation: ", len(train_data))
     print("Len test: ", len(test_data))
     print("Tot dataset: ", len(train_data) + len(test_data))
 
+    if not os.path.exists(args.save_path):
+        os.makedirs(args.save_path)
+
     torch.save(train_data, f'{args.save_path}/train_augmented_data_{args.label_kind}.pt')
-    torch.save(test_data, f'{args.save_path}/test_augmented_data_{args.label_kind}.pt')
+    torch.save(test_data, f'{args.save_path}/test_data_{args.label_kind}.pt')
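
Note: with the new argparse defaults, "python data/create_dataset.py" now runs with no flags at all. For the programmatic route, a minimal sketch of the updated constructor call (the import path is assumed from this repository's layout; verbose=False suppresses the noisy-segment warnings):

    # Sketch only: mirrors the new __main__ block and its defaults.
    from data.create_dataset import DatasetCreator

    d = DatasetCreator('hci-tagging-database/preproc_data', 'Arsl',
                       physio_f=128, gaze_f=60, block_len=30,
                       sample_len=10, overlap=0, verbose=False)
    data, labels = d.save_to_list()  # tqdm now reports per-subject progress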

data/dataloader.py (+3 −3)
@@ -19,11 +19,11 @@ def __getitem__(self, idx):
         sample, label = self.data[idx]
         return sample, label
 
-def MyDataLoader(root, label_kind, batch_size, num_workers=1):
+def MyDataLoader(train_file, test_file, batch_size, num_workers=1):
     print("----Loading dataset----")
 
-    training = torch.load(root + f"/train_augmented_data_{label_kind}.pt") # Loads an object saved with torch.save() from a file
-    validation = torch.load(root + f"/test_augmented_data_{label_kind}.pt") # Loads an object saved with torch.save() from a file
+    training = torch.load(train_file) # Loads an object saved with torch.save() from a file
+    validation = torch.load(test_file) # Loads an object saved with torch.save() from a file
 
     train_dataset = MyDataset(training)
     eval_dataset = MyDataset(validation)
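
Note: MyDataLoader now takes explicit .pt paths instead of deriving names from (root, label_kind), decoupling it from the naming scheme in create_dataset.py. A hypothetical call using the files that script now writes (the return value is assumed here; this hunk only shows dataset construction):

    # Sketch only: paths match the files create_dataset.py now saves.
    from data.dataloader import MyDataLoader

    loaders = MyDataLoader(
        'hci-tagging-database/torch_datasets/train_augmented_data_Arsl.pt',
        'hci-tagging-database/torch_datasets/test_data_Arsl.pt',
        batch_size=32)  # assumed to return the train/eval loaders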

data/preprocessing.py (+32 −26)
@@ -3,7 +3,8 @@
 import xml.etree.ElementTree as ET
 import mne
 import pandas as pd
-import numpy as np
+import numpy as np
+from tqdm import tqdm
 
 def list_files(directory, sorted_dir):
     """List all files (i.e. their paths) in the dataset directory. Need sorted argument
@@ -19,15 +20,8 @@ def list_files(directory, sorted_dir):
             files.append(single)
     return files
 
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--sessions_path', type=str)
-    parser.add_argument('--save_path', type=str)
-    args = parser.parse_args()
-
-    sessions_dir = list_files(args.sessions_path, sorted_dir=True)
-
-    for dir_id in range(len(sessions_dir)):
+def preprocess(sessions_dir, save_path, verbose=False):
+    for dir_id in tqdm(range(len(sessions_dir)), desc='Preprocessing'):
         dir = sessions_dir[dir_id]
 
         # SESSION.XML --------------------------------------------------------------
@@ -43,13 +37,15 @@
         # Get subject infos
         session = root.attrib['cutNr']
         subject = root[0].attrib['id']
-        print("\033[94mCurrently considerig sub:{}, session:{}\033[0m".format(subject, session))
-
+
+        if verbose:
+            print("\033[94mCurrently considering sub:{}, session:{}\033[0m".format(subject, session))
+
         # PHYSIOLOGICAL DATA -------------------------------------------------------
         physio_file = os.path.join(dir, "Part_{}_S_Trial{}_emotion.bdf".format(subject, int(session)//2))
-        raw = mne.io.read_raw_bdf(physio_file, preload=True)
+        raw = mne.io.read_raw_bdf(physio_file, preload=True, verbose=verbose)
         # documentation for mne raw: https://mne.tools/1.0/auto_tutorials/raw/10_raw_overview.html#sphx-glr-auto-tutorials-raw-10-raw-overview-py
-
+
         # Get general info from file
         n_time_samps = raw.n_times # number of samples
         time_secs = raw.times # corresponding second [s] of each sample
@@ -62,7 +58,7 @@
         # Resample all data from 256 Hz to 128 Hz, passing status channels as stimuli
         # documentation: https://mne.tools/0.24/auto_tutorials/preprocessing/30_filtering_resampling.html
         # OBS: this function applies first a brick-wall filter at the Nyquist frequency of the desired new sampling rate (i.e. 64Hz)
-        raw = raw.resample(sfreq=128, stim_picks=46)
+        raw = raw.resample(sfreq=128, stim_picks=46, verbose=verbose)
 
         # Get status channel to extract video's initial and ending samples (to remove baseline pre/post-stimulus)
         status_ch, time = raw[-1] # extract last channel
@@ -80,37 +76,37 @@
 
         # PREPROCESSING
         # EEG -------------------------------------------------------------------------------------------
-        raw_eeg = raw.copy().pick_channels(EEG_CH)
+        raw_eeg = raw.copy().pick_channels(EEG_CH, verbose=verbose)
         # Referencing to average reference
         # documentation: https://mne.tools/dev/generated/mne.set_eeg_reference.html
-        raw_eeg = raw_eeg.set_eeg_reference(ref_channels='average')
+        raw_eeg = raw_eeg.set_eeg_reference(ref_channels='average', verbose=verbose)
         # Artifact removal and filtering
         # documentation: https://mne.tools/0.24/auto_tutorials/preprocessing/30_filtering_resampling.html
         # Power line at 50 Hz, as proved with plots below
         # Band pass FIR filter from 1 - 45 Hz => still need to apply notch filter at 50Hz,
         # since the filter is not acting upon the 50Hz component (neglectable attenuation)
-        raw_eeg = raw_eeg.notch_filter(50)
-        raw_eeg = raw_eeg.filter(l_freq=1, h_freq=45)
+        raw_eeg = raw_eeg.notch_filter(50, verbose=verbose)
+        raw_eeg = raw_eeg.filter(l_freq=1, h_freq=45, verbose=verbose)
         # OBS the order between notch and bandpass filter is inrelevant (TRIED)
         # EOG removal: not considered for now, TODO?
         # ECG -------------------------------------------------------------------------------------------
-        raw_ecg = raw.copy().pick_channels(ECG_CH)
+        raw_ecg = raw.copy().pick_channels(ECG_CH, verbose=verbose)
         # Artifact removal and filtering
         # documentation: https://mne.tools/0.24/auto_tutorials/preprocessing/30_filtering_resampling.html
         # Power line at 50 Hz, as proved with plots below
         # Band pass FIR filter from 0.5 - 45 Hz => still need to apply notch filter at 50Hz,
         # since the filter is not acting upon the 50Hz component (neglectable attenuation)
-        raw_ecg = raw_ecg.notch_filter(50)
-        raw_ecg = raw_ecg.filter(l_freq=0.5, h_freq=45)
+        raw_ecg = raw_ecg.notch_filter(50, verbose=verbose)
+        raw_ecg = raw_ecg.filter(l_freq=0.5, h_freq=45, verbose=verbose)
         # OBS the order between notch and bandpass filter is inrelevant (TRIED)
         # GSR -------------------------------------------------------------------------------------------
-        raw_gsr = raw.copy().pick_channels([GSR_CH])
+        raw_gsr = raw.copy().pick_channels([GSR_CH], verbose=verbose)
         # Artifact removal and filtering
         # documentation: https://mne.tools/0.24/auto_tutorials/preprocessing/30_filtering_resampling.html
         # Power line at 50 Hz, as proved with plots below
         # Low pass FIR filter at 60 Hz => still need to apply notch filter at 50Hz
-        raw_gsr = raw_gsr.notch_filter(50)
-        raw_gsr = raw_gsr.filter(l_freq=None, h_freq=60)
+        raw_gsr = raw_gsr.notch_filter(50, verbose=verbose)
+        raw_gsr = raw_gsr.filter(l_freq=None, h_freq=60, verbose=verbose)
         # OBS the order between notch and bandpass filter is inrelevant (TRIED)
 
         # Extract data (removing baseline pre/post-stimulus)
@@ -227,7 +223,7 @@
             mean_eye_dist.append(-1)
 
         # SAVE CURRENT TRIAL IN NEW DATASET (IN CSV FORMAT) ------------------------
-        path_name = os.path.join(args.save_path, 'S'+f"{int(subject):02}")
+        path_name = os.path.join(save_path, 'S'+f"{int(subject):02}")
         if not os.path.exists(path_name):
             os.makedirs(path_name)
 
@@ -245,4 +241,14 @@
             f.write(trial_labels[i] + "\n")
         f.close()
 
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--sessions_path', type=str, default='hci-tagging-database/Sessions', help='Path to Sessions folder')
+    parser.add_argument('--save_path', type=str, default='hci-tagging-database/preproc_data', help='Path to save preprocessed data')
+    parser.add_argument('--verbose', type=bool, action=argparse.BooleanOptionalAction, default=False)
+    args = parser.parse_args()
+
+    sessions_dir = list_files(args.sessions_path, sorted_dir=True)
+    preprocess(sessions_dir, args.save_path, args.verbose)
+
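
Note: factoring the session loop into preprocess() makes this step importable as well as runnable from the CLI. A sketch of the programmatic equivalent of the new __main__ block (paths are the argparse defaults above):

    # Sketch only: same calls the new __main__ block makes.
    from data.preprocessing import list_files, preprocess

    sessions_dir = list_files('hci-tagging-database/Sessions', sorted_dir=True)
    preprocess(sessions_dir, 'hci-tagging-database/preproc_data', verbose=False)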

earlystopping.py (+13 −11)
@@ -5,7 +5,7 @@
 
 class EarlyStopping:
     """Early stops the training if validation metric doesn't improve after a given patience."""
-    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print, rank=0):
+    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print, mode='max'):
         """
         Args:
             patience (int): How long to wait after last time validation loss improved.
@@ -17,7 +17,9 @@ def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', tra
             path (str): Path for the checkpoint to be saved to.
                         Default: 'checkpoint.pt'
             trace_func (function): trace print function.
-                        Default: print
+                        Default: print
+            mode (str): 'min' to save model when metric decreases (e.g. loss), 'max' when it increases (e.g. accuracy).
+                        Default: 'max'
         """
         self.patience = patience
         self.verbose = verbose
@@ -28,30 +30,30 @@ def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', tra
         self.delta = delta
         self.path = path
         self.trace_func = trace_func
-        self.rank = rank
+        self.mode = mode
 
     def __call__(self, metric, model):
-
-        #score = -val_loss, i.e. in case of a loss its negative value is given
-        score = metric
+        if self.mode == 'min':
+            # Loss: negate it so that a decrease counts as an improvement
+            score = -metric
+        else:
+            score = metric
 
         if self.best_score is None: # initial step
             self.best_score = score
-            # if self.rank == 0:
-            #     self.save_checkpoint(metric, model)
+            self.save_checkpoint(metric, model)
         elif score < self.best_score + self.delta: # if not improving (i.e. not growing)
             self.counter += 1
             self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
             if self.counter >= self.patience:
                 self.early_stop = True
         else: # if improved (i.e. grown)
             self.best_score = score
-            # if self.rank == 0:
-            #     self.save_checkpoint(metric, model)
+            self.save_checkpoint(metric, model)
             self.counter = 0
 
     def save_checkpoint(self, metric, model):
-        '''Saves model when validation loss decrease.'''
+        '''Saves model when metric improves.'''
         if self.verbose:
             if metric > 0:
                 self.trace_func(f'Validation auc increased ({self.val_metric_min:.6f} --> {metric:.6f}). Saving model ...')
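
Note: checkpoints are now saved unconditionally (the rank-gated calls are gone) and the comparison direction is configurable via mode. A hypothetical training loop (train_one_epoch and model are placeholders, not part of this commit):

    # Sketch only: mode='min' tracks a validation loss; the default 'max' suits AUC/accuracy.
    from earlystopping import EarlyStopping

    stopper = EarlyStopping(patience=7, verbose=True, path='checkpoint.pt', mode='min')
    for epoch in range(100):
        val_loss = train_one_epoch(model)  # placeholder train/eval step
        stopper(val_loss, model)           # negates the loss internally when mode='min'
        if stopper.early_stop:
            break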
