# GNN.py
import copy
import json
import os
import random
from datetime import datetime
from typing import List, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import SAGEConv, global_mean_pool


class StaticGNN(torch.nn.Module):
"""
    A graph neural network (GNN) for graphs with static node features.

    This class encapsulates the construction, training, and inference stages of a
    GraphSAGE-style model (it stacks SAGEConv layers). Training uses mean squared
    error (MSE) as the loss function and Adam as the optimizer.

    Attributes:
        device: The device to run the model on (CPU or GPU).
        lr: Learning rate for the optimizer.
        batch_size: Batch size for training.
        l2_reg: L2 regularization strength (applied as Adam weight decay).
        weights: Per-dimension weights applied to the output losses.
        params: A dictionary storing hyperparameters and configuration.
        losses: A list storing the per-epoch losses.
    """
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
device: torch.device,
num_hidden_layers: int = 1,
batch_size: int = 32,
l2_reg: float = 0.0,
lr: float = 0.001,
weights: Optional[List[float]] = None,
random_seed: Optional[int] = None,
message: Optional[str] = None):
"""
Initialize the StaticGNN model.
Args:
input_dim (int): Dimensionality of the input features.
hidden_dim (int): Dimensionality of the hidden layers.
output_dim (int): Dimensionality of the output layer.
            device (torch.device): Device to run the model on (e.g., 'cuda' or 'cpu').
            num_hidden_layers (int, optional): Number of (Linear, SAGEConv) hidden blocks. Defaults to 1.
batch_size (int, optional): Batch size for training. Defaults to 32.
l2_reg (float, optional): L2 regularization strength. Defaults to 0.0.
lr (float, optional): Learning rate for the optimizer. Defaults to 0.001.
weights (Optional[List[float]], optional): Weights for the output dimensions. Defaults to None.
random_seed (Optional[int], optional): Random seed for reproducibility. Defaults to None.
message (Optional[str], optional): Custom message for the model. Defaults to None.
"""
        super().__init__()
self.device = device
self.num_hidden_layers = num_hidden_layers
self.lr = lr
self.batch_size = batch_size
self.l2_reg = l2_reg
        if weights is None:
            self.weights = torch.ones(output_dim, device=self.device)
        else:
            assert len(weights) == output_dim, "Length of weights must be equal to output_dim"
            # Cast to float32 so the weights match the float outputs and targets
            self.weights = torch.tensor(weights, dtype=torch.float, device=self.device)
        # Layers: one input SAGEConv followed by num_hidden_layers (Linear, SAGEConv) blocks
self.convs = torch.nn.ModuleList([SAGEConv(input_dim, hidden_dim)])
for _ in range(num_hidden_layers):
self.convs.append(torch.nn.Linear(hidden_dim, hidden_dim))
self.convs.append(SAGEConv(hidden_dim, hidden_dim))
self.fc = torch.nn.Linear(hidden_dim, output_dim)
# Model utilities
self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=l2_reg)
self.criterion = torch.nn.MSELoss(reduction='none')
self.losses = []
self.states = []
self.params = {'model': 'GNN',
'input_dim': input_dim,
'hidden_dim': hidden_dim,
'output_dim': output_dim,
'device': str(device),
'num_hidden_layers': num_hidden_layers,
'lr': lr,
'batch_size': batch_size,
'l2_reg': l2_reg,
'weights': weights,
'random_seed': random_seed,
'message': message}
self.to(device)
# Set the random seed if it's provided
if random_seed is not None:
            os.environ['PYTHONHASHSEED'] = str(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)
def forward(self, data):
x, edge_index = data.x.to(self.device), data.edge_index.to(self.device)
# First SageConv Layer (Outside of loop)
x = self.convs[0](x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
# Using loop for subsequent Linear + SAGEConv layers
for i in range(1, len(self.convs) - 1, 2): # step by 2 because we have two layers in each iteration
# Fully Connected Layer
x = self.convs[i](x)
x = F.relu(x)
x = F.dropout(x, training=self.training)
# SAGEConv Layer
x = self.convs[i+1](x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
        # Pool node embeddings to a graph-level representation, then apply the output layer
x = global_mean_pool(x, batch=data.batch)
x = self.fc(x)
return x
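    # Shape sketch for forward (illustrative note, assuming a batch of B graphs with
    # N nodes each and F = input_dim features): node features enter as (B*N, F);
    # every SAGEConv/Linear layer maps them to (B*N, hidden_dim); global_mean_pool
    # averages node embeddings per graph to (B, hidden_dim); and self.fc produces
    # the final (B, output_dim) predictions.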
def train(self, X_train, edge_index, y_train, X_val=None, y_val=None, epochs=100, use_tqdm=True, save_loss=False):
"""
Train the StaticGNN model.
Args:
X_train (np.array): Training node features.
edge_index (np.array): Edge index for training graph.
y_train (np.array): Training target values.
X_val (Optional[np.array], optional): Validation node features. Defaults to None.
y_val (Optional[np.array], optional): Validation target values. Defaults to None.
epochs (int, optional): Number of training epochs. Defaults to 100.
use_tqdm (bool, optional): Whether to display a tqdm progress bar. Defaults to True.
save_loss (bool, optional): Whether to save the training loss. Defaults to False.
"""
        super().train()  # put the module in training mode (note: this method shadows nn.Module.train)
# Data preparation for training
X_train = torch.tensor(X_train, dtype=torch.float)
edge_index = torch.tensor(edge_index, dtype=torch.long)
y_train = torch.tensor(y_train, dtype=torch.float)
        train_data_list = [Data(x=X_train[i], edge_index=edge_index, y=y_train[i].unsqueeze(0))
                           for i in range(X_train.shape[0])]
loader = DataLoader(
train_data_list, batch_size=self.batch_size, shuffle=True, pin_memory=True)
# Data preparation for validation if provided
if X_val is not None and y_val is not None:
X_val = torch.tensor(X_val, dtype=torch.float)
y_val = torch.tensor(y_val, dtype=torch.float)
            val_data_list = [Data(x=X_val[i], edge_index=edge_index, y=y_val[i].unsqueeze(0))
                             for i in range(X_val.shape[0])]
val_loader = DataLoader(val_data_list, batch_size=self.batch_size * 10)
if use_tqdm:
from tqdm import tqdm
            pbar = tqdm(total=epochs * len(loader), desc="Training Progress")
for epoch in range(epochs):
train_loss = np.zeros(y_train.shape[1])
for data in loader:
data = data.to(self.device)
self.optimizer.zero_grad()
out = self(data)
loss = self.criterion(out, data.y) * self.weights
loss.mean().backward()
self.optimizer.step()
train_loss += np.mean(loss.detach().cpu().numpy(), axis=0)
if use_tqdm:
pbar.set_description(f"Epoch {epoch}")
pbar.update()
if use_tqdm:
pbar.set_postfix({'loss': loss.mean().item()})
if save_loss:
if X_val is not None and y_val is not None:
val_losses = np.zeros(y_val.shape[1])
for val_data in val_loader:
val_data = val_data.to(self.device)
with torch.no_grad():
val_out = self(val_data)
val_loss = self.criterion(val_out, val_data.y)
val_loss = np.mean(val_loss.detach().cpu().numpy(), axis=0)
val_losses += val_loss
losses = val_losses/len(val_loader)
mean_loss = float(np.mean(losses))
else:
losses = train_loss/len(loader)
mean_loss = float(np.mean(losses))
self.losses.append((*losses.tolist(), mean_loss))
                # state_dict() returns references to the live parameter tensors; deep-copy
                # the snapshot so earlier epochs are not overwritten by later updates
                self.states.append(copy.deepcopy(self.state_dict()))
if use_tqdm:
pbar.close()
        if save_loss:
            # Smooth each loss series with a moving average. The smoothed series is
            # shorter than the raw one (mode='valid'), so 'best_epoch' indexes the
            # start of the best-scoring window rather than the exact raw epoch.
            self.losses = [self._moving_average(loss).tolist() for loss in np.transpose(self.losses)]
            self.params['losses'] = self.losses
            self.params['best_epoch'] = int(np.argmin(self.losses[-1]))
            # Timestamp the output filenames
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            # Save the parameters/losses and the best model state, creating the
            # output directories if they do not exist yet
            os.makedirs('results', exist_ok=True)
            os.makedirs('models', exist_ok=True)
            with open(f'results/losses_{timestamp}.json', 'w') as f:
                json.dump(self.params, f)
            torch.save(self.states[self.params['best_epoch']], f'models/best_model_{timestamp}.pth')
def predict(self, X_test, edge_index, y_test):
"""
Predict with the trained StaticGNN model.
Args:
X_test (np.array): Test node features.
edge_index (np.array): Edge index for test graph.
y_test (np.array): Test target values.
Returns:
np.array: Predicted values for the test data.
"""
        super().train(False)  # switch to evaluation mode
X_test = torch.tensor(X_test, dtype=torch.float)
edge_index = torch.tensor(edge_index, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.float)
test_data_list = [Data(x=X_test[i], edge_index=edge_index, y=y_test[i])
for i in range(X_test.shape[0])]
test_loader = DataLoader(test_data_list, batch_size=self.batch_size, pin_memory=True)
        # Use the model's actual output dimensionality instead of a hard-coded width
        predictions = np.zeros((0, self.fc.out_features))
with torch.no_grad():
for data in test_loader:
data = data.to(self.device)
pred = self(data)
predictions = np.vstack((predictions, pred.cpu().numpy()))
return predictions
    def _moving_average(self, data, window_size=5):
        """Compute a moving average with numpy; mode='valid' shortens the series by window_size - 1."""
        weights = np.ones(window_size) / window_size
        return np.convolve(data, weights, mode='valid')
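

# --- Minimal usage sketch (illustrative, not part of the original module). ---
# The toy ring graph, feature sizes, and hyperparameters below are assumptions
# chosen only to demonstrate the intended call sequence of StaticGNN.
if __name__ == '__main__':
    # Assumed toy problem: 50 graph samples sharing one 10-node ring graph,
    # 4 static features per node, and a 3-dimensional regression target per graph.
    num_samples, num_nodes, num_features, num_targets = 50, 10, 4, 3
    rng = np.random.default_rng(0)
    X = rng.standard_normal((num_samples, num_nodes, num_features)).astype(np.float32)
    y = rng.standard_normal((num_samples, num_targets)).astype(np.float32)
    # Undirected ring graph expressed as a (2, num_edges) edge index
    src = np.arange(num_nodes)
    dst = (src + 1) % num_nodes
    edge_index = np.stack([np.concatenate([src, dst]), np.concatenate([dst, src])])

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = StaticGNN(input_dim=num_features, hidden_dim=16, output_dim=num_targets,
                      device=device, random_seed=42)
    model.train(X, edge_index, y, epochs=5, use_tqdm=False, save_loss=False)
    preds = model.predict(X, edge_index, y)
    print('Predictions shape:', preds.shape)  # expected: (50, 3)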