update gad-nr code
YingtongDou committed Aug 14, 2023
1 parent 07cec20 commit fecec02
Showing 5 changed files with 347 additions and 115 deletions.
33 changes: 12 additions & 21 deletions pygod/detector/gadnr.py
@@ -125,18 +125,13 @@ def __init__(self,
gpu=-1,
batch_size=0,
num_neigh=-1,
graphlet_size=4,
selected_motif=True,
cache_dir=None,
verbose=0,
save_emb=False,
compile_model=False,
**kwargs):

if backbone is not None:
warnings.warn("Backbone is not used in GUIDE")

super(GUIDE, self).__init__(hid_dim=(hid_a, hid_s),
super(GADNR, self).__init__(hid_dim=(hid_a, hid_s),
num_layers=num_layers,
dropout=dropout,
weight_decay=weight_decay,
@@ -172,20 +167,16 @@ def init_model(self, **kwargs):

def forward_model(self, data):

batch_size = data.batch_size

x = data.x.to(self.device)
s = data.s.to(self.device)
edge_index = data.edge_index.to(self.device)

x_, s_ = self.model(x, s, edge_index)

score = self.model.loss_func(x[:batch_size],
x_[:batch_size],
s[:batch_size],
s_[:batch_size],
self.alpha)
l1, h0 = self.model(data.x, data.edge_index)

losses = self.model.loss_func(l1,
ground_truth_degree_matrix,
h0,
neighbor_dict,
device,
data.x,
data.edge_index)

loss = torch.mean(score)
loss, loss_per_node, h_loss, degree_loss, feature_loss = losses

return loss, score.detach().cpu()
return loss, loss_per_node,h_loss,degree_loss,feature_loss
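
As a usage sketch (not part of the commit), the updated detector should plug into the usual PyGOD workflow; this assumes GADNR is exported from pygod.detector like the other detectors, the standard fit/decision_function interface, and the pygod.utils.load_data helper, with an illustrative dataset name:

from pygod.detector import GADNR
from pygod.utils import load_data

data = load_data('inj_cora')                # any torch_geometric.data.Data graph
detector = GADNR(gpu=-1, verbose=1)         # other hyperparameters left at their defaults
detector.fit(data)
scores = detector.decision_function(data)   # per-node outlier scores
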
4 changes: 2 additions & 2 deletions pygod/nn/decoder.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""Graph Decoders"""
# Author: Kay Liu <[email protected]>
# Author: Kay Liu <[email protected]>, Yingtong Dou <[email protected]>
# License: BSD 2 clause

import torch
Expand Down Expand Up @@ -60,7 +60,7 @@ def __init__(self,
**kwargs)

def forward(self, x, edge_index):
"""
r"""
Forward computation.
Parameters
108 changes: 28 additions & 80 deletions pygod/nn/encoder.py
@@ -4,30 +4,47 @@
# License: BSD 2 clause

import torch
import torch.nn as nn
import torch.nn.functional as F

from .conv import GNAConv


class GNA(torch.nn.Module):
class GNA(nn.Module):
"""
Graph Node Attention Network (GNA). See :cite:`yuan2021higher` for
more details.
Parameters
----------
in_dim : int
Input dimension of node features.
hid_dim : int
Hidden dimension of the model.
num_layers : int
Number of layers in the model.
out_dim : int
Output dimension of the model.
dropout : float, optional
Dropout rate. Default: ``0.``.
act : callable activation function or None, optional
Activation function if not None.
Default: ``torch.nn.functional.relu``.
"""
def __init__(self,
in_channels,
hidden_channels,
in_dim,
hid_dim,
num_layers,
out_channels,
dropout,
act):
out_dim,
dropout=0.,
act=torch.nn.functional.relu):
super().__init__()
self.layers = torch.nn.ModuleList()
self.layers.append(GNAConv(in_channels, hidden_channels))
self.layers = nn.ModuleList()
self.layers.append(GNAConv(in_dim, hid_dim))
for layer in range(num_layers - 2):
self.layers.append(GNAConv(hidden_channels,
hidden_channels))
self.layers.append(GNAConv(hidden_channels, out_channels))
self.layers.append(GNAConv(hid_dim,
hid_dim))
self.layers.append(GNAConv(hid_dim, out_dim))

self.dropout = dropout
self.act = act
@@ -54,72 +71,3 @@ def forward(self, s, edge_index):
if self.act is not None:
s = self.act(s)
return s
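
For reference, a minimal sketch of driving the GNA encoder above on a toy graph (not part of the commit; it assumes GNA is importable from pygod.nn.encoder and that GNAConv accepts the (in_dim, hid_dim) call shown in the constructor):

import torch
from pygod.nn.encoder import GNA

s = torch.randn(6, 16)                     # structure features for 6 nodes
edge_index = torch.tensor([[0, 1, 2, 3, 4],
                           [1, 2, 3, 4, 5]])
gna = GNA(in_dim=16, hid_dim=32, num_layers=3, out_dim=8)
out = gna(s, edge_index)                   # -> torch.Size([6, 8])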


class MLP_GAD_NR(torch.nn.Module):
"""
The personalized MLP module used by GAD_NR
Source: https://github.com/Graph-COM/GAD-NR/blob/master/GAD-NR.ipynb
"""
def __init__(self, num_layers, input_dim, hid_dim, output_dim):
super(MLP_GAD_NR, self).__init__()

self.linear_or_not = True # default is linear model
self.num_layers = num_layers

if num_layers < 1:
raise ValueError("number of layers should be positive!")
elif num_layers == 1:
# Linear model
self.linear = nn.Linear(input_dim, output_dim)
else:
# Multi-layer model
self.linear_or_not = False
self.linears = torch.nn.ModuleList()
self.batch_norms = torch.nn.ModuleList()

self.linears.append(nn.Linear(input_dim, hid_dim))
for layer in range(num_layers - 2):
self.linears.append(nn.Linear(hid_dim, hid_dim))
self.linears.append(nn.Linear(hid_dim, output_dim))

for layer in range(num_layers - 1):
self.batch_norms.append(nn.BatchNorm1d((hid_dim)))

def forward(self, x):
"""
Forward computation.
Parameters
----------
x : torch.Tensor
Input node features.
Returns
-------
h : torch.Tensor
Transformed node feature embeddings.
"""
if self.linear_or_not:
# If linear model
return self.linear(x)
else:
# If MLP
h = x
for layer in range(self.num_layers - 1):
h = self.linears[layer](h)

if len(h.shape) > 2:
h = torch.transpose(h, 0, 1)
h = torch.transpose(h, 1, 2)

h = self.batch_norms[layer](h)

if len(h.shape) > 2:
h = torch.transpose(h, 1, 2)
h = torch.transpose(h, 0, 1)

h = F.relu(h)
h = self.linears[self.num_layers - 1](h)

return h
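
An aside on the 3-D branch in the forward pass above (illustrative shapes, not from the diff): when h has shape [sample_size, num_nodes, hid_dim], the two transposes move hid_dim into the channel slot expected by BatchNorm1d and then restore the original layout, as in this standalone sketch:

import torch
import torch.nn as nn

sample_size, num_nodes, hid_dim = 5, 10, 8
h = torch.randn(sample_size, num_nodes, hid_dim)
bn = nn.BatchNorm1d(hid_dim)

h = torch.transpose(h, 0, 1)   # [num_nodes, sample_size, hid_dim]
h = torch.transpose(h, 1, 2)   # [num_nodes, hid_dim, sample_size]
h = bn(h)                      # normalize over the hid_dim channels
h = torch.transpose(h, 1, 2)
h = torch.transpose(h, 0, 1)   # back to [sample_size, num_nodes, hid_dim]
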
124 changes: 112 additions & 12 deletions pygod/nn/gadnr.py
@@ -1,11 +1,13 @@
import math
import random

import torch
import torch.nn as nn
from torch_geometric.nn import GIN
import torch.multiprocessing as mp
from torch_geometric.nn import GIN, SAGEConv, PNAConv
from torch_geometric.utils import to_dense_adj

from .encoder import MLP_GAD_NR
from .decoder import DotProductDecoder
from .nn import MLP_GAD_NR, MLP_generator, FNN_GAD_NR
from .functional import double_recon_loss


@@ -51,7 +53,7 @@ def __init__(self,
hid_dim=64,
num_layers=2,
sample_size=2,
neibor_num_list=None,
neighbor_num_list=None,
lambda_loss1=1e-2,
lambda_loss2=1e-3,
lambda_loss3=1e-4,
@@ -79,10 +81,10 @@ def __init__(self,

self.gaussian_mean = nn.Parameter(
torch.FloatTensor(sample_size, hid_dim).uniform_(-0.5 / hid_dim,
0.5 / hid_dim)).to(device)
0.5 / hid_dim)).to(device)
self.gaussian_log_sigma = nn.Parameter(
torch.FloatTensor(sample_size, hid_dim).uniform_(-0.5 / hid_dim,
0.5 / hid_dim)).to(device)
0.5 / hid_dim)).to(device)
self.m = torch.distributions.Normal(torch.zeros(sample_size, hid_dim),
torch.ones(sample_size, hid_dim))

@@ -100,16 +102,15 @@ def __init__(self,
torch.FloatTensor(hid_dim).uniform_(-0.5 / hid_dim, 0.5 / hid_dim)).to(device)
self.mlp_m = torch.distributions.Normal(torch.zeros(hid_dim), torch.ones(hid_dim))

self.mlp_mean = FNN(hid_dim, hid_dim, hid_dim, 3)
self.mlp_sigma = FNN(hid_dim, hid_dim, hid_dim, 3)
self.mlp_mean = FNN_GAD_NR(hid_dim, hid_dim, hid_dim, 3)
self.mlp_sigma = FNN_GAD_NR(hid_dim, hid_dim, hid_dim, 3)
self.softplus = nn.Softplus()

self.mean_agg = SAGEConv(hid_dim, hid_dim, aggr='mean', normalize = False)
# self.mean_agg = GraphSAGE(hid_dim, hid_dim, aggr='mean', num_layers=1)
self.std_agg = PNAConv(hid_dim, hid_dim, aggregators=["std"],scalers=["identity"], deg=neighbor_num_list)
self.layer1_generator = MLP_generator(hid_dim, hid_dim)

# GNN Encoder
# Encoder
self.shared_encoder = backbone(in_channels=hid_dim,
hidden_channels=hid_dim,
num_layers=encoder_layers,
@@ -118,10 +119,41 @@ def __init__(self,
act=act,
**kwargs)


self.loss_func = double_recon_loss
# Decoder
self.degree_decoder = FNN_GAD_NR(hid_dim, hid_dim, 1, 4)
self.feature_decoder = FNN_GAD_NR(hid_dim, hid_dim, in_dim, 3)
self.degree_loss_func = nn.MSELoss()
self.feature_loss_func = nn.MSELoss()
self.pool = mp.Pool(4)
self.in_dim = in_dim
self.sample_size = sample_size
self.init_projection = FNN_GAD_NR(in_dim, hid_dim, hid_dim, 1)
self.emb = None


# Sample neighbors from the neighbor set; if a node has fewer neighbors than the sample size, pad with zero embeddings.
def sample_neighbors(self, indexes, neighbor_dict, gt_embeddings):
sampled_embeddings_list = []
mark_len_list = []
for index in indexes:
sampled_embeddings = []
neighbor_indexes = neighbor_dict[index]
if len(neighbor_indexes) < self.sample_size:
mask_len = len(neighbor_indexes)
sample_indexes = neighbor_indexes
else:
sample_indexes = random.sample(neighbor_indexes, self.sample_size)
mask_len = self.sample_size
for index in sample_indexes:
sampled_embeddings.append(gt_embeddings[index].tolist())
if len(sampled_embeddings) < self.sample_size:
for _ in range(self.sample_size - len(sampled_embeddings)):
sampled_embeddings.append(torch.zeros(self.out_dim).tolist())
sampled_embeddings_list.append(sampled_embeddings)
mark_len_list.append(mask_len)

return sampled_embeddings_list, mark_len_list
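
A self-contained illustration of the padding rule above (toy values, not part of the diff): with sample_size=3 and only two neighbors, the sampled list is padded with zero vectors while mask_len records the real neighbor count.

import random
import torch

sample_size, out_dim = 3, 4
neighbor_indexes = [7, 9]                      # this node only has two neighbors
gt_embeddings = {7: torch.ones(out_dim), 9: torch.full((out_dim,), 2.)}

if len(neighbor_indexes) < sample_size:
    sample_indexes, mask_len = neighbor_indexes, len(neighbor_indexes)
else:
    sample_indexes, mask_len = random.sample(neighbor_indexes, sample_size), sample_size

sampled = [gt_embeddings[i].tolist() for i in sample_indexes]
sampled += [torch.zeros(out_dim).tolist()] * (sample_size - len(sampled))
# sampled -> [[1., 1., 1., 1.], [2., 2., 2., 2.], [0., 0., 0., 0.]], mask_len -> 2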

def forward(self, x, edge_index):
"""
Forward computation.
@@ -143,6 +175,7 @@ def forward(self, x, edge_index):

# feature projection
x = self.linear(x)
# TODO add extra projection for GIN model

# encode feature matrix
self.emb = self.shared_encoder(x, edge_index)
@@ -155,6 +188,73 @@ def forward(self, x, edge_index):

return x_, s_

def loss_func(self,
gij,
ground_truth_degree_matrix,
h0,
neighbor_dict,
device,
h,
edge_index):
"""
Obtain the dense adjacency matrix of the graph.
Parameters
----------
data : torch_geometric.data.Data
Input graph.
"""

# TODO: dissect the decoders and move them into the forward function

# Degree decoder below:
tot_nodes = gij.shape[0]
degree_logits = self.degree_decoding(gij)
ground_truth_degree_matrix = torch.unsqueeze(ground_truth_degree_matrix, dim=1)
degree_loss = self.degree_loss_func(degree_logits, ground_truth_degree_matrix.float())
degree_loss_per_node = (degree_logits-ground_truth_degree_matrix).pow(2)
_, degree_masks = torch.max(degree_logits.data, dim=1)
h_loss = 0
feature_loss = 0

# layer 1
loss_list = []
loss_list_per_node = []
feature_loss_list = []
# Sample multiple times to remove noise
for _ in range(3):
local_index_loss_sum = 0
local_index_loss_sum_per_node = []
indexes = []
h0_prime = self.feature_decoder(gij)
feature_losses = self.feature_loss_func(h0, h0_prime)
feature_losses_per_node = (h0-h0_prime).pow(2).mean(1)
feature_loss_list.append(feature_losses_per_node)

local_index_loss, local_index_loss_per_node = self.reconstruction_neighbors2(gij,h0,edge_index)

loss_list.append(local_index_loss)
loss_list_per_node.append(local_index_loss_per_node)

loss_list = torch.stack(loss_list)
h_loss += torch.mean(loss_list)

loss_list_per_node = torch.stack(loss_list_per_node)
h_loss_per_node = torch.mean(loss_list_per_node,dim=0)

feature_loss_per_node = torch.mean(torch.stack(feature_loss_list),dim=0)
feature_loss += torch.mean(torch.stack(feature_loss_list))

h_loss_per_node = h_loss_per_node.reshape(tot_nodes,1)
degree_loss_per_node = degree_loss_per_node.reshape(tot_nodes,1)
feature_loss_per_node = feature_loss_per_node.reshape(tot_nodes,1)

loss = self.lambda_loss1 * h_loss + degree_loss * self.lambda_loss3 + self.lambda_loss2 * feature_loss
loss_per_node = self.lambda_loss1 * h_loss_per_node + degree_loss_per_node * self.lambda_loss3 + self.lambda_loss2 * feature_loss_per_node

return loss,loss_per_node,h_loss_per_node,degree_loss_per_node,feature_loss_per_node
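
In compact form, the objective assembled above is (a restatement of the code, using the constructor's weights, not an addition to it):

    loss = lambda_loss1 * h_loss + lambda_loss2 * feature_loss + lambda_loss3 * degree_loss

with the same weighted sum applied to the per-node terms to produce loss_per_node.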


@staticmethod
def process_graph(data):
"""