forked from graphdeeplearning/benchmarking-gnns
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
48 lines (39 loc) · 1.39 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""
Load a dataset object by name, as selected by the user in the main script.
"""
from data.superpixels import SuperPixDataset
from data.molecules import MoleculeDataset
from data.TUs import TUsDataset
from data.SBMs import SBMsDataset
from data.TSP import TSPDataset
from data.COLLAB import COLLABDataset
from data.CSL import CSLDataset
def LoadData(DATASET_NAME: str):
    """
    Build and return the dataset object for the given dataset name.

    This function is called from the main script.

    Args:
        DATASET_NAME: Name of the dataset to load. Supported names are
            'MNIST', 'CIFAR10', 'ZINC', 'ENZYMES', 'DD', 'PROTEINS_full',
            'SBM_CLUSTER', 'SBM_PATTERN', 'TSP', 'OGBL-COLLAB' and 'CSL'.

    Returns:
        The dataset object constructed by the dataset class matching
        DATASET_NAME.

    Raises:
        ValueError: If DATASET_NAME is not one of the supported names.
    """
    # handling for MNIST or CIFAR Superpixels
    if DATASET_NAME in ('MNIST', 'CIFAR10'):
        return SuperPixDataset(DATASET_NAME)

    # handling for (ZINC) molecule dataset
    if DATASET_NAME == 'ZINC':
        return MoleculeDataset(DATASET_NAME)

    # handling for the TU Datasets
    TU_DATASETS = ['ENZYMES', 'DD', 'PROTEINS_full']
    if DATASET_NAME in TU_DATASETS:
        return TUsDataset(DATASET_NAME)

    # handling for SBM datasets
    SBM_DATASETS = ['SBM_CLUSTER', 'SBM_PATTERN']
    if DATASET_NAME in SBM_DATASETS:
        return SBMsDataset(DATASET_NAME)

    # handling for TSP dataset
    if DATASET_NAME == 'TSP':
        return TSPDataset(DATASET_NAME)

    # handling for COLLAB dataset
    if DATASET_NAME == 'OGBL-COLLAB':
        return COLLABDataset(DATASET_NAME)

    # handling for the CSL (Circular Skip Links) Dataset
    if DATASET_NAME == 'CSL':
        return CSLDataset(DATASET_NAME)

    # Fail loudly here instead of implicitly returning None, which would
    # only surface later as an unrelated error in the caller.
    raise ValueError(f"Unknown dataset name: {DATASET_NAME!r}")