# models.py
import torch
from torch import nn


class DiagonalNet(nn.Module):
    """Depth-L diagonal linear network: effective predictor w = u**L - v**L."""

    def __init__(self, alpha, L, dimD):
        super().__init__()
        # Balanced initialization: every entry of u and v equals
        # alpha / sqrt(2 * dimD), so ||(u, v)|| = alpha and w = 0 at init.
        self.u = nn.Parameter(alpha / ((dimD * 2) ** 0.5) * torch.ones(dimD))
        self.v = nn.Parameter(alpha / ((dimD * 2) ** 0.5) * torch.ones(dimD))
        self.L = L

    def get_w(self):
        # Effective linear predictor induced by the depth-L parameterization.
        return self.u ** self.L - self.v ** self.L

    def forward(self, x):
        # x: (batch, dimD) -> predictions of shape (batch, 1).
        return (x @ self.get_w()).unsqueeze(-1)


class HomoMLP(nn.Module):
    """L-layer ReLU MLP; positively homogeneous when biases are disabled."""

    def __init__(self, init_scale, L, dimD, dimH, dimO, first_layer_bias,
                 init_method='he'):
        super().__init__()
        self.L = L
        self.dimD = dimD
        self.dimH = dimH
        self.dimO = dimO
        self.layers = []
        seq = []
        dimLast = self.dimD
        for k in range(self.L):
            # Hidden layers are dimH wide; the final layer maps to dimO.
            dimNext = self.dimH if k < self.L - 1 else self.dimO
            # At most the first layer carries a bias, controlled by first_layer_bias.
            l = nn.Linear(dimLast, dimNext, bias=(k == 0 and first_layer_bias))
            self.layers.append(l)
            seq.append(l)
            dimLast = dimNext
            if k < self.L - 1:
                seq.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*seq)
        if init_method == 'he':
            # He (Kaiming) init: ReLU gain for hidden layers, linear gain for
            # the output layer; biases start at zero.
            for i, l in enumerate(self.layers):
                if i < self.L - 1:
                    torch.nn.init.kaiming_normal_(l.weight.data, nonlinearity='relu')
                else:
                    torch.nn.init.kaiming_normal_(l.weight.data, nonlinearity='linear')
                if l.bias is not None:
                    l.bias.data.zero_()
        # Uniformly rescale all weights to control the overall initialization scale.
        for l in self.layers:
            l.weight.data.mul_(init_scale)

    def forward(self, x):
        return self.net(x)


class MatrixFactorization(nn.Module):
    """Symmetric factorized parameterization M = U @ U.T - V @ V.T."""

    def __init__(self, alpha, dimD):
        super().__init__()
        self.dimD = dimD
        # Scaled-identity initialization, so M = 0 at init.
        self.U = nn.Parameter(alpha * torch.eye(dimD))
        self.V = nn.Parameter(alpha * torch.eye(dimD))

    def forward(self, x1, x2):
        # Return the entries M[x1, x2] for integer index tensors x1, x2.
        M = self.U @ self.U.T - self.V @ self.V.T
        return M[x1, x2]
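

# A minimal usage sketch, not part of the original file; the hyperparameter
# values below are illustrative assumptions, not the repository's settings.
if __name__ == '__main__':
    # Diagonal linear network: one scalar prediction per input row.
    net = DiagonalNet(alpha=0.01, L=2, dimD=100)
    print(net(torch.randn(32, 100)).shape)  # torch.Size([32, 1])

    # Homogeneous ReLU MLP: 3 layers, no biases anywhere.
    mlp = HomoMLP(init_scale=1.0, L=3, dimD=100, dimH=64, dimO=1,
                  first_layer_bias=False)
    print(mlp(torch.randn(32, 100)).shape)  # torch.Size([32, 1])

    # Matrix factorization: query individual entries of M = U U^T - V V^T.
    mf = MatrixFactorization(alpha=0.1, dimD=50)
    idx_rows = torch.tensor([0, 1, 2])
    idx_cols = torch.tensor([3, 4, 5])
    print(mf(idx_rows, idx_cols))  # all zeros at initialization, since M = 0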