@@ -37,6 +37,12 @@ def _load_memory_mapped_array(file_name):
 parser.add_argument(
     "--num_nodes", type=int, default=8, help="Number of nodes in the graph"
 )
+parser.add_argument(
+    "--num_workers",
+    type=int,
+    default=4,
+    help="Number of DataLoader worker processes",
+)
 parser.add_argument(
     "--gpu_idx",
     type=int,
@@ -97,6 +103,7 @@ def _load_memory_mapped_array(file_name):
 _EXPERIMENT_NAME = args.exp_name
 _BATCH_SIZE = args.batch_size
 num_feats = args.num_feats
+num_workers = args.num_workers

 # set up input and targets from files
 x = _load_memory_mapped_array(f"psd_features_data_X")
@@ -149,7 +156,6 @@ def _load_memory_mapped_array(file_name):
 # Dataloader========================================================================================================

 # use WeightedRandomSampler to balance the training dataset
-NUM_WORKERS = 4

 labels_unique, counts = np.unique(y, return_counts=True)

@@ -172,7 +178,7 @@ def _load_memory_mapped_array(file_name):
     dataset=train_dataset,
     batch_size=_BATCH_SIZE,
     sampler=weighted_sampler,
-    num_workers=NUM_WORKERS,
+    num_workers=num_workers,
     pin_memory=True,
 )

@@ -181,7 +187,7 @@ def _load_memory_mapped_array(file_name):
     dataset=train_dataset,
     batch_size=_BATCH_SIZE,
     shuffle=False,
-    num_workers=NUM_WORKERS,
+    num_workers=num_workers,
     pin_memory=True,
 )

@@ -194,7 +200,7 @@ def _load_memory_mapped_array(file_name):
     dataset=test_dataset,
     batch_size=_BATCH_SIZE,
     shuffle=False,
-    num_workers=NUM_WORKERS,
+    num_workers=num_workers,
     pin_memory=True,
 )

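For reference, here is a minimal standalone sketch (illustrative names only, not the repo's actual script) of the pattern this change introduces: an argparse --num_workers flag wired into torch.utils.data.DataLoader, replacing the previously hard-coded NUM_WORKERS = 4.

# Minimal sketch, assuming only PyTorch is installed; the dataset below is a
# placeholder standing in for the memory-mapped feature/target arrays used in
# the real script.
import argparse

import torch
from torch.utils.data import DataLoader, TensorDataset


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="Number of DataLoader worker processes",
    )
    parser.add_argument("--batch_size", type=int, default=16)
    args = parser.parse_args()

    # Placeholder dataset: random features and binary targets.
    dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))

    loader = DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,  # CLI-controlled instead of hard-coded
        pin_memory=True,
    )

    for batch_x, batch_y in loader:
        pass  # training / evaluation step would go here


if __name__ == "__main__":
    main()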