-
Notifications
You must be signed in to change notification settings - Fork 1
/
center.py
52 lines (38 loc) · 1.34 KB
/
center.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from torch.autograd import Variable
import numpy as np
import torch
def Relaxcenter(Y, B, V, mu, vul, nta):
    """Relax the discrete centers to a continuous variable V and optimize
    it with plain gradient descent (the relaxation optimization approach).

    Fixes vs. the original: drops the deprecated ``torch.autograd.Variable``
    wrapper (and the invalid ``torch.FloatTensor(Variable(...))`` call), and
    builds the constant target once on ``V``'s device instead of calling
    ``.cuda()`` inside the loop — so the function also runs on CPU-only
    machines, with identical behavior on GPU.

    Args:
        Y:   matrix of shape (n, d); multiplied as ``V @ Y``.
        B:   target matrix of shape (m, d).
        V:   relaxed center matrix of shape (m, n). Must have
             ``requires_grad=True``; it is updated in place via ``V.data``.
        mu:  unused here; kept for interface compatibility — TODO confirm
             with callers whether it can be removed project-wide.
        vul: weight of the inter (decorrelation) loss term.
        nta: weight of the quantization loss term.

    Returns:
        (Center_u, V_u): ``V.sign()`` and ``V`` itself, both detached
        copies moved to CPU.
    """
    alpha = 0.03  # initial learning rate
    num = 1       # step counter driving the LR decay schedule
    n = Y.size(0)
    # Target Gram matrix: V.t() @ V should approach K*(2I - 1), i.e. +K on
    # the diagonal and -K off-diagonal, K = V.size(0).
    T = 2 * torch.eye(n) - torch.ones(n)
    # Constant (no grad); built once on V's device instead of per-iteration
    # .cuda() transfers.
    TK = (V.size(0) * T).to(V.device).detach()
    for _ in range(200):
        intra_loss = (V @ Y - B).pow(2).mean()
        inter_loss = (V.t() @ V - TK).pow(2).mean()
        quantization_loss = (V - V.sign()).pow(2).mean()
        loss = intra_loss + (vul) * inter_loss + (nta) * quantization_loss
        loss.backward()
        num += 1
        # Decay the learning rate by 10x twice, late in the schedule
        # (same steps as the original: counter values 150 and 180).
        if num == 150 or num == 180:
            alpha = alpha * 0.1
        V.data = V.data - alpha * V.grad.data
        V.grad.data.zero_()
    V_u = V.data.cpu()
    Center_u = V_u.sign()
    return Center_u, V_u
def Discretecenter(Y, B, C, mu, vul):
    """Solve the DCC (Discrete Cyclic Coordinate Descent) problem.

    Updates ``C`` one row at a time, holding the other rows fixed, and
    binarizes each row with ``sign()``. Because the update is cyclic,
    later rows see the already-updated earlier rows within this sweep.

    Fix vs. the original: ``Q = Y @ B.t()`` is loop-invariant (neither
    ``Y`` nor ``B`` changes inside the loop) and is now computed once
    instead of once per row — same result, one matmul instead of n.

    Args:
        Y:   matrix of shape (n, d), one row per row of ``C``.
        B:   matrix of shape (m, d).
        C:   matrix of shape (n, m); modified IN PLACE row by row.
        mu:  unused here; kept for interface compatibility — TODO confirm
             with callers whether it can be removed project-wide.
        vul: weight of the balance penalty ``C_prime.t() @ 1``.

    Returns:
        ``C.t()``: the transposed updated matrix, shape (m, n).
    """
    ones_vector = torch.ones([C.size(0) - 1])
    # Hoisted out of the loop: Y and B are never modified below.
    Q = Y @ B.t()
    for i in range(C.shape[0]):
        q = Q[i, :]
        v = Y[i, :]
        # All rows except row i (rows < i already hold this sweep's updates).
        Y_prime = torch.cat((Y[:i, :], Y[i + 1:, :]))
        C_prime = torch.cat((C[:i, :], C[i + 1:, :]))
        C[i, :] = (q - C_prime.t() @ Y_prime @ v
                   - vul * C_prime.t() @ ones_vector).sign()
    return C.t()