Skip to content

Commit 46bc0b6

Browse files
drivanovUbuntu
authored and
Ubuntu
committed
Adding --num_workers input parameter to the EEG_GCNN example. (dmlc#6467)
1 parent 52243a7 commit 46bc0b6

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

examples/pytorch/eeg-gcnn/main.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ def _load_memory_mapped_array(file_name):
3737
parser.add_argument(
3838
"--num_nodes", type=int, default=8, help="Number of nodes in the graph"
3939
)
40+
parser.add_argument(
41+
"--num_workers",
42+
type=int,
43+
default=4,
44+
help="Number of epochs used to train",
45+
)
4046
parser.add_argument(
4147
"--gpu_idx",
4248
type=int,
@@ -97,6 +103,7 @@ def _load_memory_mapped_array(file_name):
97103
_EXPERIMENT_NAME = args.exp_name
98104
_BATCH_SIZE = args.batch_size
99105
num_feats = args.num_feats
106+
num_workers = args.num_workers
100107

101108
# set up input and targets from files
102109
x = _load_memory_mapped_array(f"psd_features_data_X")
@@ -149,7 +156,6 @@ def _load_memory_mapped_array(file_name):
149156
# Dataloader========================================================================================================
150157

151158
# use WeightedRandomSampler to balance the training dataset
152-
NUM_WORKERS = 4
153159

154160
labels_unique, counts = np.unique(y, return_counts=True)
155161

@@ -172,7 +178,7 @@ def _load_memory_mapped_array(file_name):
172178
dataset=train_dataset,
173179
batch_size=_BATCH_SIZE,
174180
sampler=weighted_sampler,
175-
num_workers=NUM_WORKERS,
181+
num_workers=num_workers,
176182
pin_memory=True,
177183
)
178184

@@ -181,7 +187,7 @@ def _load_memory_mapped_array(file_name):
181187
dataset=train_dataset,
182188
batch_size=_BATCH_SIZE,
183189
shuffle=False,
184-
num_workers=NUM_WORKERS,
190+
num_workers=num_workers,
185191
pin_memory=True,
186192
)
187193

@@ -194,7 +200,7 @@ def _load_memory_mapped_array(file_name):
194200
dataset=test_dataset,
195201
batch_size=_BATCH_SIZE,
196202
shuffle=False,
197-
num_workers=NUM_WORKERS,
203+
num_workers=num_workers,
198204
pin_memory=True,
199205
)
200206

0 commit comments

Comments
 (0)