Beltrami Guard
zknus authored Nov 4, 2022 (1 parent ec01d9b, commit f42c166)
grb/model/torch/belguard.py: new file, 244 additions, 0 deletions
import dgl
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GATConv
from scipy.sparse import lil_matrix
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize

from grb.model.torch.beltrami import GCNConv
from grb.utils.normalize import GCNAdjNorm


class BELGuard(nn.Module):
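    """GCN built from Beltrami GCNConv layers with a GNNGuard-style defense:
    when ``attention`` is on, edges are re-weighted by the cosine similarity
    of their endpoint features before every convolution (see ``att_coef``).
    """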
def __init__(self,
in_features,
out_features,
hidden_features,
n_layers,
activation=F.relu,
layer_norm=False,
dropout=True,
feat_norm=None,
adj_norm_func=GCNAdjNorm,
drop=0.0,
attention=True):
super(BELGuard, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.feat_norm = feat_norm
self.adj_norm_func = adj_norm_func
if type(hidden_features) is int:
hidden_features = [hidden_features] * (n_layers - 1)
elif type(hidden_features) is list or type(hidden_features) is tuple:
assert len(hidden_features) == (n_layers - 1), "Incompatible sizes between hidden_features and n_layers."
n_features = [in_features] + hidden_features + [out_features]

self.layers = nn.ModuleList()
for i in range(n_layers):
if layer_norm:
self.layers.append(nn.LayerNorm(n_features[i]))
self.layers.append(GCNConv(in_features=n_features[i],
out_features=n_features[i + 1],
activation=activation if i != n_layers - 1 else None,
dropout=dropout if i != n_layers - 1 else 0.0))
self.reset_parameters()
self.drop = drop
self.drop_learn = torch.nn.Linear(2, 1)
self.attention = attention

@property
def model_type(self):
return "torch"

def reset_parameters(self):
for layer in self.layers:
layer.reset_parameters()

def forward(self, x, adj):
for layer in self.layers:
if isinstance(layer, nn.LayerNorm):
x = layer(x)
else:
if self.attention:
adj = self.att_coef(x, adj)
x = layer(x, adj)

return x

def att_coef(self, features, adj):
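        """Re-weight edges by endpoint feature similarity.

        Computes cosine similarity between the features of each edge's
        endpoints, zeroes out similarities below 0.1, row-normalizes the
        weights, optionally applies a learned hard edge dropout
        (``self.drop``), adds weighted self-loops when the diagonal is
        empty, and returns the exponentiated weights as a sparse torch
        adjacency.
        """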
edge_index = adj._indices()

n_node = features.shape[0]
        row, col = edge_index[0].cpu().numpy(), edge_index[1].cpu().numpy()

        features_copy = features.detach().cpu().numpy()
        sim_matrix = cosine_similarity(X=features_copy, Y=features_copy)  # pairwise cosine similarity of node features
        sim = sim_matrix[row, col]
        sim[sim < 0.1] = 0  # prune edges between weakly similar endpoints

"""build a attention matrix"""
att_dense = lil_matrix((n_node, n_node), dtype=np.float32)
att_dense[row, col] = sim
if att_dense[0, 0] == 1:
att_dense = att_dense - sp.diags(att_dense.diagonal(), offsets=0, format="lil")
# normalization, make the sum of each row is 1
att_dense_norm = normalize(att_dense, axis=1, norm='l1')

"""add learnable dropout, make character vector"""
if self.drop:
character = np.vstack((att_dense_norm[row, col].A1,
att_dense_norm[col, row].A1))
character = torch.from_numpy(character.T).to(features.device)
drop_score = self.drop_learn(character)
drop_score = torch.sigmoid(drop_score) # do not use softmax since we only have one element
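            # binarize with two Threshold ops: scores above 0.5 become 1, the rest 0 (hard keep/drop per edge)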
mm = torch.nn.Threshold(0.5, 0)
drop_score = mm(drop_score)
mm_2 = torch.nn.Threshold(-0.49, 1)
drop_score = mm_2(-drop_score)
drop_decision = drop_score.clone().requires_grad_()
drop_matrix = lil_matrix((n_node, n_node), dtype=np.float32)
drop_matrix[row, col] = drop_decision.cpu().data.numpy().squeeze(-1)
att_dense_norm = att_dense_norm.multiply(drop_matrix.tocsr()) # update, remove the 0 edges

        if att_dense_norm[0, 0] == 0:  # diagonal is empty: add weighted self-loops
            degree = (att_dense_norm != 0).sum(1).A1
            lam = 1 / (degree + 1)  # +1 counts the node itself
            self_weight = sp.diags(np.array(lam), offsets=0, format="lil")
            att = att_dense_norm + self_weight  # add the self-loops
else:
att = att_dense_norm

row, col = att.nonzero()
att_adj = np.vstack((row, col))
att_edge_weight = att[row, col]
        att_edge_weight = np.exp(att_edge_weight)  # exponentiate the weights (softmax-like rescaling)
att_edge_weight = torch.tensor(np.array(att_edge_weight)[0], dtype=torch.float32).to(features.device)
att_adj = torch.tensor(att_adj, dtype=torch.int64).to(features.device)

shape = (n_node, n_node)
        # coalesce so that ._indices() succeeds when att_coef runs again on this adjacency
        new_adj = torch.sparse_coo_tensor(att_adj, att_edge_weight, shape).coalesce()

return new_adj


class GATGuard(nn.Module):
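    """GAT variant of the guard model: DGL ``GATConv`` layers combined with
    the same similarity-based edge re-weighting as ``BELGuard`` (here the
    pruned adjacency is kept in SciPy form for ``dgl.from_scipy``).
    """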
def __init__(self,
in_features,
out_features,
hidden_features,
n_layers,
n_heads,
activation=F.leaky_relu,
layer_norm=False,
feat_norm=None,
adj_norm_func=GCNAdjNorm,
drop=False,
attention=True,
dropout=0.0):

super(GATGuard, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.feat_norm = feat_norm
self.adj_norm_func = adj_norm_func
if type(hidden_features) is int:
hidden_features = [hidden_features] * (n_layers - 1)
elif type(hidden_features) is list or type(hidden_features) is tuple:
assert len(hidden_features) == (n_layers - 1), "Incompatible sizes between hidden_features and n_layers."
n_features = [in_features] + hidden_features + [out_features]

self.layers = nn.ModuleList()
for i in range(n_layers):
if layer_norm:
if i == 0:
self.layers.append(nn.LayerNorm(n_features[i]))
else:
self.layers.append(nn.LayerNorm(n_features[i] * n_heads))
self.layers.append(GATConv(in_feats=n_features[i] * n_heads if i != 0 else n_features[i],
out_feats=n_features[i + 1],
num_heads=n_heads if i != n_layers - 1 else 1,
activation=activation if i != n_layers - 1 else None))
self.drop = drop
self.drop_learn = torch.nn.Linear(2, 1)
self.attention = attention
if dropout > 0.0:
self.dropout = nn.Dropout(dropout)
else:
self.dropout = None

@property
def model_type(self):
return "dgl"

def forward(self, x, adj):
graph = dgl.from_scipy(adj).to(x.device)
graph.ndata['features'] = x

for i, layer in enumerate(self.layers):
if isinstance(layer, nn.LayerNorm):
x = layer(x)
else:
if self.attention:
adj = self.att_coef(x, adj)
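                    # rebuild the DGL graph from the pruned adjacency; GATConv uses only its structure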
graph = dgl.from_scipy(adj).to(x.device)
graph.ndata['features'] = x
x = layer(graph, x).flatten(1)
if i != len(self.layers) - 1:
if self.dropout is not None:
x = self.dropout(x)

return x

def att_coef(self, features, adj):
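        """Same edge re-weighting as ``BELGuard.att_coef``, but consumes and
        returns a SciPy sparse adjacency instead of a torch sparse tensor.
        """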
adj = adj.tocoo()
n_node = features.shape[0]
row, col = adj.row, adj.col

        features_copy = features.detach().cpu().numpy()
        sim_matrix = cosine_similarity(X=features_copy, Y=features_copy)  # pairwise cosine similarity of node features
        sim = sim_matrix[row, col]
        sim[sim < 0.1] = 0  # prune edges between weakly similar endpoints

"""build a attention matrix"""
att_dense = lil_matrix((n_node, n_node), dtype=np.float32)
att_dense[row, col] = sim
if att_dense[0, 0] == 1:
att_dense = att_dense - sp.diags(att_dense.diagonal(), offsets=0, format="lil")
# normalization, make the sum of each row is 1
att_dense_norm = normalize(att_dense, axis=1, norm='l1')

"""add learnable dropout, make character vector"""
if self.drop:
character = np.vstack((att_dense_norm[row, col].A1,
att_dense_norm[col, row].A1))
character = torch.from_numpy(character.T).to(features.device)
drop_score = self.drop_learn(character)
drop_score = torch.sigmoid(drop_score) # do not use softmax since we only have one element
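            # same double-threshold binarization as BELGuard.att_coef: 1 if score > 0.5 else 0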
mm = torch.nn.Threshold(0.5, 0)
drop_score = mm(drop_score)
mm_2 = torch.nn.Threshold(-0.49, 1)
drop_score = mm_2(-drop_score)
drop_decision = drop_score.clone().requires_grad_()
drop_matrix = lil_matrix((n_node, n_node), dtype=np.float32)
drop_matrix[row, col] = drop_decision.cpu().data.numpy().squeeze(-1)
att_dense_norm = att_dense_norm.multiply(drop_matrix.tocsr()) # update, remove the 0 edges

        if att_dense_norm[0, 0] == 0:  # diagonal is empty: add weighted self-loops
            degree = (att_dense_norm != 0).sum(1).A1
            lam = 1 / (degree + 1)  # +1 counts the node itself
            self_weight = sp.diags(np.array(lam), offsets=0, format="lil")
            att = att_dense_norm + self_weight  # add the self-loops
else:
att = att_dense_norm

row, col = att.nonzero()
        att_edge_weight = att[row, col]
        att_edge_weight = np.exp(att_edge_weight)  # exponentiate the weights (softmax-like rescaling)
        att_edge_weight = np.asarray(att_edge_weight.ravel())[0]
        new_adj = sp.csr_matrix((att_edge_weight, (row, col)), shape=(n_node, n_node))  # explicit shape guards against trailing isolated nodes

return new_adj
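
A minimal usage sketch of the two models, under stated assumptions: grb is installed, GCNConv from grb.model.torch.beltrami accepts (x, adj) with a sparse torch adjacency, and all sizes and data below are illustrative, not from the commit.

import numpy as np
import scipy.sparse as sp
import torch

from grb.model.torch.belguard import BELGuard, GATGuard

n_nodes, n_feats, n_classes = 100, 32, 7
x = torch.randn(n_nodes, n_feats)

# random symmetric adjacency, illustrative only
adj_sp = sp.random(n_nodes, n_nodes, density=0.05, format="coo", dtype=np.float32)
adj_sp = ((adj_sp + adj_sp.T) > 0).astype(np.float32).tocoo()

# BELGuard.att_coef reads adj._indices(), so it needs a coalesced sparse torch tensor
indices = torch.tensor(np.vstack((adj_sp.row, adj_sp.col)), dtype=torch.int64)
values = torch.tensor(adj_sp.data)
adj_torch = torch.sparse_coo_tensor(indices, values, (n_nodes, n_nodes)).coalesce()

model = BELGuard(in_features=n_feats, out_features=n_classes,
                 hidden_features=64, n_layers=2, dropout=0.0)
logits = model(x, adj_torch)         # shape: (n_nodes, n_classes)

# GATGuard takes the SciPy adjacency directly (its forward calls dgl.from_scipy)
gat = GATGuard(in_features=n_feats, out_features=n_classes,
               hidden_features=64, n_layers=2, n_heads=4)
logits_gat = gat(x, adj_sp.tocsr())  # shape: (n_nodes, n_classes)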
