forked from zhongpeixiang/RGNN
model_simple.py
137 lines (113 loc) · 5.5 KB
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch_geometric.nn import SGConv, global_add_pool
from torch_scatter import scatter_add

def maybe_num_nodes(index, num_nodes=None):
    return index.max().item() + 1 if num_nodes is None else num_nodes


def add_remaining_self_loops(edge_index,
                             edge_weight=None,
                             fill_value=1,
                             num_nodes=None):
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    row, col = edge_index

    mask = row != col  # non-self-loop edges
    inv_mask = ~mask   # existing self-loop edges (`1 - mask` fails on bool tensors in recent PyTorch)
    loop_weight = torch.full(
        (num_nodes, ),
        fill_value,
        dtype=None if edge_weight is None else edge_weight.dtype,
        device=edge_index.device)

    if edge_weight is not None:
        assert edge_weight.numel() == edge_index.size(1)
        # keep the weights of self-loops that already exist
        remaining_edge_weight = edge_weight[inv_mask]
        if remaining_edge_weight.numel() > 0:
            loop_weight[row[inv_mask]] = remaining_edge_weight
        edge_weight = torch.cat([edge_weight[mask], loop_weight], dim=0)

    loop_index = torch.arange(0, num_nodes, dtype=row.dtype, device=row.device)
    loop_index = loop_index.unsqueeze(0).repeat(2, 1)
    edge_index = torch.cat([edge_index[:, mask], loop_index], dim=1)

    return edge_index, edge_weight
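
# Toy illustration (not in the original file): for a 2-node graph with
#   edge_index  = [[0, 1, 1], [1, 0, 1]], edge_weight = [0.5, -0.5, 2.0]
# the existing self-loop (1, 1) keeps its weight 2.0, node 0 gets a new
# self-loop with fill_value 1, and the call returns
#   edge_index  = [[0, 1, 0, 1], [1, 0, 0, 1]]
#   edge_weight = [0.5, -0.5, 1.0, 2.0]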

class NewSGConv(SGConv):
    def __init__(self, num_features, num_classes, K=1, cached=False,
                 bias=True):
        super(NewSGConv, self).__init__(num_features, num_classes, K=K, cached=cached, bias=bias)

    # allow negative edge weights
    @staticmethod
    def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None):
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ),
                                     dtype=dtype,
                                     device=edge_index.device)

        fill_value = 1 if not improved else 2
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)
        row, col = edge_index
        # degrees are computed from |edge_weight| so that negative weights
        # are normalized in magnitude but keep their sign
        deg = scatter_add(torch.abs(edge_weight), row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_weight=None):
        """K-step propagation with the sign-preserving normalization;
        the propagated features are cached when `cached=True`."""
        if not self.cached or self.cached_result is None:
            edge_index, norm = NewSGConv.norm(
                edge_index, x.size(0), edge_weight, dtype=x.dtype)

            for k in range(self.K):
                x = self.propagate(edge_index, x=x, norm=norm)
            self.cached_result = x

        return self.lin(self.cached_result)

    def message(self, x_j, norm):
        # x_j: (batch_size*num_nodes*num_nodes, num_features)
        # norm: (batch_size*num_nodes*num_nodes, )
        return norm.view(-1, 1) * x_j
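
# Sanity check on the normalization (illustrative, not in the original file):
# for two nodes joined by an edge of weight -1 plus fill_value-1 self-loops,
# each degree is |-1| + 1 = 2, so the (0, 1) entry becomes
# (-1) / (sqrt(2) * sqrt(2)) = -0.5: normalized in magnitude exactly as in
# vanilla SGConv, but with the sign preserved.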

class ReverseLayerF(Function):
    # gradient reversal layer (RevGrad): identity in the forward pass,
    # gradients multiplied by -alpha in the backward pass
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha
        return output, None
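
# Usage sketch (illustrative): y = ReverseLayerF.apply(x, 0.1) returns y == x,
# while y.sum().backward() leaves x.grad == -0.1 * torch.ones_like(x).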

class SymSimGCNNet(torch.nn.Module):
    def __init__(self, num_nodes, learn_edge_weight, edge_weight, num_features, num_hiddens, num_classes, K, dropout=0.5, domain_adaptation=""):
        """
        num_nodes: number of nodes in the graph
        learn_edge_weight: if True, the edge_weight is learnable
        edge_weight: initial edge matrix
        num_features: feature dim for each node/channel
        num_hiddens: a tuple of hidden dimensions
        num_classes: number of emotion classes
        K: number of propagation steps of the SGConv layer
        dropout: dropout rate in the final linear layer
        domain_adaptation: set to "RevGrad" to enable the gradient-reversal domain classifier
        """
        super(SymSimGCNNet, self).__init__()
        self.domain_adaptation = domain_adaptation
        self.num_nodes = num_nodes
        self.xs, self.ys = torch.tril_indices(self.num_nodes, self.num_nodes, offset=0)
        edge_weight = edge_weight.reshape(self.num_nodes, self.num_nodes)[self.xs, self.ys]  # lower-triangular values, diagonal included
        self.edge_weight = nn.Parameter(edge_weight, requires_grad=learn_edge_weight)
        self.dropout = dropout
        self.conv1 = NewSGConv(num_features=num_features, num_classes=num_hiddens[0], K=K)
        self.fc = nn.Linear(num_hiddens[0], num_classes)
        if self.domain_adaptation in ["RevGrad"]:
            self.domain_classifier = nn.Linear(num_hiddens[0], 2)

    def forward(self, data, alpha=0):
        batch_size = len(data.y)
        x, edge_index = data.x, data.edge_index
        # rebuild the symmetric adjacency from the learnable lower triangle
        edge_weight = torch.zeros((self.num_nodes, self.num_nodes), device=edge_index.device)
        edge_weight[self.xs.to(edge_weight.device), self.ys.to(edge_weight.device)] = self.edge_weight
        edge_weight = edge_weight + edge_weight.transpose(1, 0) - torch.diag(edge_weight.diagonal())  # mirror the lower triangle onto the upper triangle, counting the diagonal once
        edge_weight = edge_weight.reshape(-1).repeat(batch_size)
        x = F.relu(self.conv1(x, edge_index, edge_weight))

        # domain classification
        domain_output = None
        if self.domain_adaptation in ["RevGrad"]:
            reverse_x = ReverseLayerF.apply(x, alpha)
            domain_output = self.domain_classifier(reverse_x)
        x = global_add_pool(x, data.batch, size=batch_size)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.fc(x)
        return x, domain_output
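

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original RGNN code. All
    # shapes below are illustrative assumptions (62 EEG channels with 5
    # band features each, 3 emotion classes, a batch of 4 graphs), and it
    # assumes the torch_geometric 1.x API this file was written against.
    from torch_geometric.data import Batch, Data

    num_nodes, num_features, num_classes = 62, 5, 3
    # fully connected graph over the channels, matching the dense
    # num_nodes x num_nodes edge weights built in SymSimGCNNet.forward
    edge_index = torch.tensor([[i, j] for i in range(num_nodes)
                               for j in range(num_nodes)]).t().contiguous()
    data_list = [Data(x=torch.randn(num_nodes, num_features),
                      edge_index=edge_index,
                      y=torch.tensor([0]))
                 for _ in range(4)]
    batch = Batch.from_data_list(data_list)

    model = SymSimGCNNet(num_nodes=num_nodes,
                         learn_edge_weight=True,
                         edge_weight=torch.rand(num_nodes, num_nodes),
                         num_features=num_features,
                         num_hiddens=(16,),
                         num_classes=num_classes,
                         K=2,
                         domain_adaptation="RevGrad")
    out, domain_out = model(batch, alpha=0.1)
    print(out.shape)         # torch.Size([4, 3])
    print(domain_out.shape)  # torch.Size([248, 2]), one row per node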