11
11
12
12
13
13
class CitationNetworks (Dataset ):
14
- def __init__ (self , dataset_dir = DATASET_DIR , directed = False ) -> None :
14
+ def __init__ (self , dataset_dir = DATASET_DIR ) -> None :
15
15
super ().__init__ ()
16
16
17
- self .dataset_name = None # will be defined in child classes
17
+ # will be defined in child classes
18
+ self .dataset_name = None
19
+ self .directed = None
20
+ self .num_features = None
18
21
19
22
self .dataset_dir = dataset_dir
20
- self .directed = directed
21
23
22
- self .num_sample_per_class = 20
24
+ self .num_train_samples_per_class = 20
25
+ self .num_test_samples = 1000
23
26
24
27
def __getitem__ (self , index ):
25
28
return self .X [index ], self .Y [index ]
@@ -28,6 +31,11 @@ def __len__(self):
28
31
return self .num_nodes
29
32
30
33
def preprocess (self ):
34
+ '''
35
+ The preprocess methods are from the following references:
36
+ - http://proceedings.mlr.press/v48/yanga16.pdf
37
+ - https://arxiv.org/pdf/1609.02907.pdf
38
+ '''
31
39
cites_path = os .path .join (
32
40
self .dataset_dir , "{}.cites" .format (self .dataset_name )
33
41
)
@@ -42,21 +50,24 @@ def preprocess(self):
42
50
self .dataset_dir , "{}.content" .format (self .dataset_name )
43
51
)
44
52
45
- col_names = ["Node" ] + list (range (3703 )) + ["Label" ]
53
+ col_names = ["Node" ] + list (range (self . num_features )) + ["Label" ]
46
54
47
55
content_df = pd .read_csv (
48
56
content_path , sep = "\t " , names = col_names , header = None
49
57
)
50
- content_df ["Feature" ] = content_df [range (3703 )].agg (list , axis = 1 )
58
+ content_df ["Feature" ] = content_df [range (self .num_features )]\
59
+ .agg (list , axis = 1 )
51
60
content_df = content_df [["Node" , "Feature" , "Label" ]]
52
61
53
62
node_list = np .array ([str (node ) for node in content_df ["Node" ].values ])
54
63
node2idx = {node : idx for idx , node in enumerate (node_list )}
55
64
num_nodes = node_list .shape [0 ]
56
65
66
+ # Row normalization for the feature matrix
57
67
X = np .array (
58
68
[np .array (feature ) for feature in content_df ["Feature" ].values ]
59
69
)
70
+ X = X / np .sum (X , axis = - 1 , keepdims = True )
60
71
num_feature_maps = X .shape [- 1 ]
61
72
62
73
class_list = np .unique (content_df ["Label" ].values )
@@ -69,16 +80,17 @@ def preprocess(self):
69
80
drop_indices = []
70
81
71
82
for i , row in cites_df .iterrows ():
72
- if row ["To" ] not in node_list or row ["From" ] not in node_list :
83
+ if str (row ["To" ]) not in node_list or \
84
+ str (row ["From" ]) not in node_list :
73
85
drop_indices .append (i )
74
86
75
87
cites_df = cites_df .drop (drop_indices )
76
88
77
89
A = np .zeros ([num_nodes , num_nodes ])
78
90
79
91
for _ , row in cites_df .iterrows ():
80
- to_ = row ["To" ]
81
- from_ = row ["From" ]
92
+ to_ = str ( row ["To" ])
93
+ from_ = str ( row ["From" ])
82
94
83
95
A [node2idx [to_ ], node2idx [from_ ]] = 1
84
96
if not self .directed :
@@ -104,23 +116,69 @@ def preprocess(self):
104
116
105
117
train_indices = np .hstack (
106
118
[
107
- np .random .choice (v , self .num_sample_per_class )
119
+ np .random .choice (v , self .num_train_samples_per_class )
108
120
for _ , v in class2indices .items ()
109
121
]
110
122
)
111
123
test_indices = np .delete (np .arange (num_nodes ), train_indices )
124
+ test_indices = np .random .choice (test_indices , self .num_test_samples )
112
125
113
126
return A , A_hat , X , Y , node_list , node2idx , num_nodes , \
114
127
num_feature_maps , class_list , class2idx , num_classes , \
115
128
class2indices , train_indices , test_indices
116
129
117
130
118
131
class Citeseer (CitationNetworks ):
119
- def __init__ (self ) -> None :
132
+ def __init__ (self , directed ) -> None :
120
133
super ().__init__ ()
121
134
135
+ self .directed = directed
136
+
137
+ self .num_features = 3703
138
+
122
139
self .dataset_name = "citeseer"
123
140
self .dataset_dir = os .path .join (self .dataset_dir , self .dataset_name )
141
+ if self .directed :
142
+ self .preprocessed_dir = os .path .join (
143
+ self .dataset_dir , "directed"
144
+ )
145
+ else :
146
+ self .preprocessed_dir = os .path .join (
147
+ self .dataset_dir , "undirected"
148
+ )
149
+ print (self .preprocessed_dir )
150
+
151
+ if not os .path .exists (self .preprocessed_dir ):
152
+ os .mkdir (self .preprocessed_dir )
153
+
154
+ if os .path .exists (os .path .join (self .preprocessed_dir , "dataset.pkl" )):
155
+ with open (
156
+ os .path .join (self .preprocessed_dir , "dataset.pkl" ), "rb"
157
+ ) as f :
158
+ dataset = pickle .load (f )
159
+ else :
160
+ dataset = self .preprocess ()
161
+ with open (
162
+ os .path .join (self .preprocessed_dir , "dataset.pkl" ), "wb"
163
+ ) as f :
164
+ pickle .dump (dataset , f )
165
+
166
+ self .A , self .A_hat , self .X , self .Y , self .node_list , self .node2idx , \
167
+ self .num_nodes , self .num_feature_maps , self .class_list , \
168
+ self .class2idx , self .num_classes , self .class2indices , \
169
+ self .train_indices , self .test_indices = dataset
170
+
171
+
172
+ class Cora (CitationNetworks ):
173
+ def __init__ (self , directed ) -> None :
174
+ super ().__init__ ()
175
+
176
+ self .directed = directed
177
+
178
+ self .num_features = 1433
179
+
180
+ self .dataset_name = "cora"
181
+ self .dataset_dir = os .path .join (self .dataset_dir , self .dataset_name )
124
182
if self .directed :
125
183
self .preprocessed_dir = os .path .join (
126
184
self .dataset_dir , "directed"
0 commit comments