
Commit 04a9f8d

completed gcn.py
1 parent 5e7cfa2 commit 04a9f8d

File tree

4 files changed, +166 -47 lines


config.json (+1 -1)
@@ -1,6 +1,6 @@
 {
     "train_config": {
-        "num_epochs": 400,
+        "num_epochs": 200,
         "learning_rate": 0.01
     },
     "model_config": {

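(Note: the config-loading code is not part of this diff, so the sketch below is an assumption. It shows how the train_config keys above would typically be read with the standard json module; train.py may do this differently.)

import json

# Hypothetical loader mirroring the keys visible in the diff above.
with open("config.json") as f:
    config = json.load(f)

train_config = config["train_config"]
num_epochs = train_config["num_epochs"]        # 200 after this commit
learning_rate = train_config["learning_rate"]  # 0.01
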
data_loaders/citation_networks.py (+69 -11)
@@ -11,15 +11,18 @@
 
 
 class CitationNetworks(Dataset):
-    def __init__(self, dataset_dir=DATASET_DIR, directed=False) -> None:
+    def __init__(self, dataset_dir=DATASET_DIR) -> None:
         super().__init__()
 
-        self.dataset_name = None  # will be defined in child classes
+        # will be defined in child classes
+        self.dataset_name = None
+        self.directed = None
+        self.num_features = None
 
         self.dataset_dir = dataset_dir
-        self.directed = directed
 
-        self.num_sample_per_class = 20
+        self.num_train_samples_per_class = 20
+        self.num_test_samples = 1000
 
     def __getitem__(self, index):
         return self.X[index], self.Y[index]
@@ -28,6 +31,11 @@ def __len__(self):
         return self.num_nodes
 
     def preprocess(self):
+        '''
+        The preprocess methods are from the following references:
+        - http://proceedings.mlr.press/v48/yanga16.pdf
+        - https://arxiv.org/pdf/1609.02907.pdf
+        '''
         cites_path = os.path.join(
             self.dataset_dir, "{}.cites".format(self.dataset_name)
         )
@@ -42,21 +50,24 @@ def preprocess(self):
             self.dataset_dir, "{}.content".format(self.dataset_name)
         )
 
-        col_names = ["Node"] + list(range(3703)) + ["Label"]
+        col_names = ["Node"] + list(range(self.num_features)) + ["Label"]
 
         content_df = pd.read_csv(
             content_path, sep="\t", names=col_names, header=None
         )
-        content_df["Feature"] = content_df[range(3703)].agg(list, axis=1)
+        content_df["Feature"] = content_df[range(self.num_features)]\
+            .agg(list, axis=1)
         content_df = content_df[["Node", "Feature", "Label"]]
 
         node_list = np.array([str(node) for node in content_df["Node"].values])
         node2idx = {node: idx for idx, node in enumerate(node_list)}
         num_nodes = node_list.shape[0]
 
+        # Row normalization for the feature matrix
         X = np.array(
             [np.array(feature) for feature in content_df["Feature"].values]
         )
+        X = X / np.sum(X, axis=-1, keepdims=True)
         num_feature_maps = X.shape[-1]
 
         class_list = np.unique(content_df["Label"].values)
@@ -69,16 +80,17 @@ def preprocess(self):
         drop_indices = []
 
         for i, row in cites_df.iterrows():
-            if row["To"] not in node_list or row["From"] not in node_list:
+            if str(row["To"]) not in node_list or \
+                    str(row["From"]) not in node_list:
                 drop_indices.append(i)
 
         cites_df = cites_df.drop(drop_indices)
 
         A = np.zeros([num_nodes, num_nodes])
 
         for _, row in cites_df.iterrows():
-            to_ = row["To"]
-            from_ = row["From"]
+            to_ = str(row["To"])
+            from_ = str(row["From"])
 
             A[node2idx[to_], node2idx[from_]] = 1
             if not self.directed:
@@ -104,23 +116,69 @@ def preprocess(self):
 
         train_indices = np.hstack(
             [
-                np.random.choice(v, self.num_sample_per_class)
+                np.random.choice(v, self.num_train_samples_per_class)
                 for _, v in class2indices.items()
             ]
         )
         test_indices = np.delete(np.arange(num_nodes), train_indices)
+        test_indices = np.random.choice(test_indices, self.num_test_samples)
 
         return A, A_hat, X, Y, node_list, node2idx, num_nodes, \
             num_feature_maps, class_list, class2idx, num_classes, \
             class2indices, train_indices, test_indices
 
 
 class Citeseer(CitationNetworks):
-    def __init__(self) -> None:
+    def __init__(self, directed) -> None:
         super().__init__()
 
+        self.directed = directed
+
+        self.num_features = 3703
+
         self.dataset_name = "citeseer"
         self.dataset_dir = os.path.join(self.dataset_dir, self.dataset_name)
+        if self.directed:
+            self.preprocessed_dir = os.path.join(
+                self.dataset_dir, "directed"
+            )
+        else:
+            self.preprocessed_dir = os.path.join(
+                self.dataset_dir, "undirected"
+            )
+        print(self.preprocessed_dir)
+
+        if not os.path.exists(self.preprocessed_dir):
+            os.mkdir(self.preprocessed_dir)
+
+        if os.path.exists(os.path.join(self.preprocessed_dir, "dataset.pkl")):
+            with open(
+                os.path.join(self.preprocessed_dir, "dataset.pkl"), "rb"
+            ) as f:
+                dataset = pickle.load(f)
+        else:
+            dataset = self.preprocess()
+            with open(
+                os.path.join(self.preprocessed_dir, "dataset.pkl"), "wb"
+            ) as f:
+                pickle.dump(dataset, f)
+
+        self.A, self.A_hat, self.X, self.Y, self.node_list, self.node2idx, \
+            self.num_nodes, self.num_feature_maps, self.class_list, \
+            self.class2idx, self.num_classes, self.class2indices, \
+            self.train_indices, self.test_indices = dataset
+
+
+class Cora(CitationNetworks):
+    def __init__(self, directed) -> None:
+        super().__init__()
+
+        self.directed = directed
+
+        self.num_features = 1433
+
+        self.dataset_name = "cora"
+        self.dataset_dir = os.path.join(self.dataset_dir, self.dataset_name)
         if self.directed:
             self.preprocessed_dir = os.path.join(
                 self.dataset_dir, "directed"
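
(Note on the new row-normalization step in preprocess(): each node's bag-of-words feature vector is divided by its row sum, so every row of X sums to 1. A toy NumPy sketch of the same operation, with made-up values:)

import numpy as np

# Toy 3-node, 4-feature binary matrix, like rows of a .content file.
X = np.array([
    [1., 0., 1., 0.],
    [0., 1., 1., 1.],
    [1., 0., 0., 0.],
])

# Same normalization as in preprocess(): each row now sums to 1.
X = X / np.sum(X, axis=-1, keepdims=True)
print(X[0])  # [0.5 0.  0.5 0. ]

One caveat: an all-zero feature row would divide by zero and yield NaNs. The citation datasets are not expected to contain such rows, but a guard such as np.maximum(np.sum(X, axis=-1, keepdims=True), 1) would make the step robust.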

models/gcn.py (+83 -9)
@@ -1,6 +1,11 @@
+import os
+
+import numpy as np
 import torch
 
 from torch.nn import Module, Linear, Dropout
+from torch.nn.init import xavier_normal_
+from torch.nn.functional import cross_entropy
 from torch.sparse import mm
 from torch.optim import Adam
 
@@ -35,42 +40,111 @@ def __init__(self, A_hat, C, H, F, num_layers, dropout, regularization):
         self.W0 = Linear(self.C, self.H, bias=False)
         self.W1 = Linear(self.H, self.F, bias=False)
 
+        xavier_normal_(self.W0.weight)
+        xavier_normal_(self.W1.weight)
+
         self.Wh = [
             Linear(self.H, self.H, bias=False)
             for _ in range(self.num_layers - 2)
         ]
 
+        for Wh in self.Wh:
+            xavier_normal_(Wh.weight)
+
         self.dropout_layer = Dropout(self.dropout)
 
         self.L2 = torch.sum(
             FloatTensor([torch.norm(param) for param in self.W0.parameters()])
         )
 
-    def forward(self, X):
+    def get_logits(self, X):
         Z = self.dropout_layer(torch.relu(mm(self.A_hat, self.W0(X))))
         for Wh in self.Wh:
             Z = torch.relu(mm(self.A_hat, Wh(Z)))
-        Z = self.dropout_layer(
-            torch.softmax(mm(self.A_hat, self.W1(Z)), dim=-1)
-        )
+        Z = self.dropout_layer(mm(self.A_hat, self.W1(Z)))
 
         return Z
 
+    def forward(self, X):
+        Z = self.get_logits(X)
+
+        return torch.softmax(Z, dim=-1)
+
     def train_model(
-        self, num_epochs, learning_rate, dataset, train_indices, test_indices
+        self, num_epochs, learning_rate, dataset, train_indices, test_indices,
+        ckpt_path
     ):
+        accs = []
+        train_losses = []
+        test_losses = []
+
+        max_acc = 0
+
         opt = Adam(self.parameters(), learning_rate)
 
         X = FloatTensor(dataset.X)
 
         for i in range(1, num_epochs + 1):
-            self.eval()
+            self.train()
 
             _, Y = dataset[train_indices]
-            Y = FloatTensor(Y)
+            Y = LongTensor(Y)
 
             Z = torch.gather(
-                self(X), dim=0, index=LongTensor(train_indices).unsqueeze(-1).repeat(1, self.F)
+                self.get_logits(X),
+                dim=0,
+                index=LongTensor(train_indices)
+                .unsqueeze(-1).repeat(1, self.F)
             )
 
-            print(self(X).shape, Z.shape, train_indices.shape, self(X)[train_indices[0]] == Z[0])
+            opt.zero_grad()
+            train_loss = cross_entropy(Z, Y)
+            (train_loss + self.regularization * self.L2).backward()
+            opt.step()
+
+            train_loss = train_loss.detach().cpu().numpy()
+
+            train_losses.append(train_loss)
+
+            with torch.no_grad():
+                self.eval()
+
+                _, Y = dataset[test_indices]
+                Y = LongTensor(Y)
+
+                Z = torch.gather(
+                    self.get_logits(X),
+                    dim=0,
+                    index=LongTensor(test_indices)
+                    .unsqueeze(-1).repeat(1, self.F)
+                )
+
+                test_loss = cross_entropy(Z, Y)
+                test_loss = test_loss.detach().cpu().numpy()
+
+                test_losses.append(test_loss)
+
+                Y = Y.detach().cpu().numpy()
+
+                Z = torch.softmax(Z, dim=-1).detach().cpu().numpy()
+                Z = np.argmax(Z, axis=-1)
+
+                acc = np.mean(Y == Z)
+
+                accs.append(acc)
+
+                print(
                    "Epoch: {}, Train Loss: {}, Test Loss: {}, Test ACC: {}"
                    .format(i, train_loss, test_loss, acc)
                )

                if acc > max_acc:
                    torch.save(
                        self.state_dict(),
                        os.path.join(
                            ckpt_path, "model.ckpt"
                        )
                    )
                    max_acc = acc
+
+        return accs, train_losses, test_losses
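
(Note: get_logits() implements the layer-wise propagation rule of Kipf & Welling, https://arxiv.org/pdf/1609.02907.pdf, i.e. Z = A_hat * ReLU(A_hat * X * W0) * W1, where A_hat is the renormalized adjacency supplied by the dataset. Its construction is not part of this diff; below is a dense NumPy sketch of the standard renormalization trick, A_hat = D~^{-1/2} (A + I) D~^{-1/2}, under the assumption that preprocess() computes something equivalent.)

import numpy as np

def renormalized_adjacency(A):
    # Add self-loops: A~ = A + I
    A_tilde = A + np.eye(A.shape[0])
    # D~ is the diagonal degree matrix of A~
    d_inv_sqrt = 1.0 / np.sqrt(A_tilde.sum(axis=-1))
    # A_hat = D~^{-1/2} A~ D~^{-1/2}, done as row/column scaling
    return A_tilde * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

Splitting the old forward() into get_logits() plus a final softmax also matters for training: torch.nn.functional.cross_entropy applies log_softmax internally and expects raw logits, so feeding it softmax outputs, as the pre-commit forward() did, double-normalizes the predictions.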

train.py (+13 -26)
@@ -5,9 +5,7 @@
 
 import torch
 
-from torch.utils.data import DataLoader, random_split
-
-from data_loaders.citation_networks import Citeseer
+from data_loaders.citation_networks import Citeseer, Cora
 
 from models.gcn import GCN
 
@@ -44,7 +42,9 @@ def main(dataset_name, directed):
     regularization = model_config["regularization"]
 
     if dataset_name == "citeseer":
-        dataset = Citeseer()
+        dataset = Citeseer(directed=directed)
+    elif dataset_name == "cora":
+        dataset = Cora(directed=directed)
 
     if torch.cuda.is_available():
         device = "cuda"
@@ -59,32 +59,19 @@ def main(dataset_name, directed):
     model = GCN(
         dataset.A_hat, dataset.num_feature_maps, H, dataset.num_classes,
         num_layers, dropout, regularization
-    )
+    ).to(device)
 
-    model.train_model(
+    accs, train_losses, test_losses = model.train_model(
         num_epochs, learning_rate, dataset, dataset.train_indices,
-        dataset.test_indices
+        dataset.test_indices, ckpt_path
     )
 
-    # train_size = dataset.train_indices.shape[0]
-    # test_size = dataset.test_indices.shape[0]
-
-    # train_dataset, test_dataset = random_split(
-    #     dataset, [train_size, test_size]
-    # )
-
-    # train_dataset.indices = dataset.train_indices
-    # test_dataset.indices = dataset.test_indices
-
-    # train_loader = DataLoader(
-    #     train_dataset, batch_size=train_size, shuffle=False
-    # )
-    # test_loader = DataLoader(
-    #     test_dataset, batch_size=test_size, shuffle=False
-    # )
-
-    # print(train_dataset.indices)
-    # print(train_loader.indices)
+    with open(os.path.join(ckpt_path, "accs.pkl"), "wb") as f:
+        pickle.dump(accs, f)
+    with open(os.path.join(ckpt_path, "train_losses.pkl"), "wb") as f:
+        pickle.dump(train_losses, f)
+    with open(os.path.join(ckpt_path, "test_losses.pkl"), "wb") as f:
+        pickle.dump(test_losses, f)
 
 
 if __name__ == "__main__":
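
(Note: since train_model() now returns its metric histories and main() pickles them, the sketch below shows one way to read them back, e.g. for plotting. ckpt_path here is a placeholder; use whatever directory train.py wrote to.)

import os
import pickle

ckpt_path = "checkpoints"  # assumption: the ckpt_path passed to train.py

with open(os.path.join(ckpt_path, "accs.pkl"), "rb") as f:
    accs = pickle.load(f)

best_epoch = max(range(len(accs)), key=lambda i: accs[i])
print("best test acc: {:.4f} at epoch {}".format(accs[best_epoch], best_epoch + 1))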
