-- nca.lua

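-- Neighbourhood Components Analysis (NCA) learns a linear map W that projects
-- the data as Z = X * W. Under that map, point i picks point j as its
-- neighbour with probability
--
--   p_ij = exp(-||z_i - z_j||^2) / sum_{k ~= i} exp(-||z_i - z_k||^2),   p_ii = 0,
--
-- and nca_grad below evaluates (and differentiates) the regularized cost that
-- this file minimizes: the average negative log-probability of picking a
-- same-class neighbour, plus an L2 penalty on W:
--
--   C(W) = -(1/N) * sum_i sum_{j : y_j = y_i} log p_ij  +  lambda * ||W||^2
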
-- dependencies:
require 'torch'
require 'optim'

-- function that implements NCA gradient:
local function nca_grad(W, X, Y, Y_tab, num_dims, lambda)

  -- dependencies:
  local pkg = require 'metriclearning'

  -- process input:
  local N = X:size(1)
  local D = X:size(2)
  W:resize(D, num_dims)

  -- compute projected data:
  local Z = torch.mm(X, W)

  -- compute pairwise squared Euclidean distance matrix:
  local P = -pkg.mahalanobis_distance(Z)

  -- compute similarities (zero out self-similarities and clamp to avoid log(0)):
  P:exp()
  for n = 1,N do
    P[n][n] = 0
  end
  local eps = 1e-19 * N
  P:apply(function(x) if x < eps then return eps else return x end end)
  P:cdiv(P:sum(2):expand(N, N))

  -- compute log-probabilities:
  local log_P = torch.log(P)
  for n = 1,N do
    log_P[n][n] = 0
  end

  -- compute NCA cost function:
  local C = 0
  for n = 1,N do
    C = C - log_P[n]:index(1, Y_tab[Y[n]]):sum()
  end
  C = C / N + lambda * torch.norm(W) * torch.norm(W)

  -- allocate some memory:
  local dC = torch.zeros(W:size())
  local dX = torch.DoubleTensor(X:size())
  local dZ = torch.DoubleTensor(Z:size())
  local weights = torch.DoubleTensor(N)

  -- compute gradient:
  for n = 1,N do

    -- compute differences in data and embedding:
    torch.add(dX, X:narrow(1, n, 1):expand(X:size()), -X) -- is negation allocating new memory?
    torch.add(dZ, Z:narrow(1, n, 1):expand(Z:size()), -Z) -- is negation allocating new memory?

    -- compute "weights" for final multiplication:
    local inds = Y_tab[Y[n]]
    torch.mul(weights, P[n], -(inds:nElement()) + 1)
    weights:indexCopy(1, inds, weights:index(1, inds):add(1)) -- can this be done without memcopy?
    weights[n] = weights[n] - 1
    weights:resize(N, 1)

    -- accumulate final gradient:
    dZ:cmul(torch.expand(weights, dZ:size()))
    dC:addmm(dX:t(), dZ)
  end
  dC:mul(2 / N)
  dC:add(2 * lambda, W)

  -- return cost function and gradient (flattened):
  dC:resize(dC:nElement())
  return C, dC
end
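
-- Example (a minimal sketch): evaluating the loss and gradient directly with
-- hand-built inputs, e.g. when plugging nca_grad into a different optimizer.
-- Y_tab maps each label value to a LongTensor holding the indices of the
-- points carrying that label, exactly as it is constructed inside nca() below:
--
--   local W0 = torch.randn(X:size(2), num_dims) * 0.1
--   local C, dC = nca_grad(W0, X, Y, Y_tab, num_dims, lambda)
--   -- C is the scalar loss; dC is the gradient, flattened to a 1-D tensor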
-- function that numerically checks the gradient of the NCA loss:
local function checkgrad(W, X, Y, Y_tab, num_dims, lambda)

  -- compute analytic gradient:
  local _,dC = nca_grad(W, X, Y, Y_tab, num_dims, lambda)
  dC:resize(W:size())

  -- compute numerical approximation to the gradient (central differences):
  local eps = 1e-7
  local dC_est = torch.DoubleTensor(dC:size())
  for i = 1,dC:size(1) do
    for j = 1,dC:size(2) do
      W[i][j] = W[i][j] + eps
      local C1 = nca_grad(W, X, Y, Y_tab, num_dims, lambda)
      W[i][j] = W[i][j] - 2 * eps
      local C2 = nca_grad(W, X, Y, Y_tab, num_dims, lambda)
      W[i][j] = W[i][j] + eps
      dC_est[i][j] = (C1 - C2) / (2 * eps)
    end
  end

  -- compute relative error of the estimate:
  local diff = torch.norm(dC - dC_est) / torch.norm(dC + dC_est)
  print('Error in NCA gradient: ' .. diff)
end
-- function that performs NCA:
local function nca(X, Y, opts)

  -- retrieve hyperparameters:
  local num_dims = opts.num_dims
  local lambda = opts.lambda

  -- initialize solution:
  local W
  if X:size(2) == num_dims then
    W = torch.eye(num_dims)
  else
    W = torch.randn(X:size(2), num_dims) * 0.1
  end

  -- count how often each label appears:
  local label_counts = {}
  for n = 1,Y:nElement() do
    if label_counts[Y[n]] == nil then
      label_counts[Y[n]] = 1
    else
      label_counts[Y[n]] = label_counts[Y[n]] + 1
    end
  end

  -- build a table that maps each label to the indices of the points carrying it:
  local Y_tab = {}
  local cur_counts = {}
  for key,val in pairs(label_counts) do
    Y_tab[key] = torch.LongTensor(val)
    cur_counts[key] = 1
  end
  for n = 1,Y:nElement() do
    Y_tab[Y[n]][cur_counts[Y[n]]] = n
    cur_counts[Y[n]] = cur_counts[Y[n]] + 1
  end

  -- perform numerical check of the gradient:
  -- checkgrad(W, X, Y, Y_tab, num_dims, lambda)

  -- perform minimization of the NCA loss with L-BFGS (optim.lswolfe is the
  -- Wolfe line search that optim.lbfgs expects):
  local state = {lineSearch = optim.lswolfe, maxIter = 250, maxEval = 500, tolFun = 1e-4, tolX = 1e-4, verbose = true}
  local func = function(x)
    local C,dC = nca_grad(x, X, Y, Y_tab, num_dims, lambda)
    return C,dC
  end
  W = optim.lbfgs(func, W, state)

  -- return linear mapping:
  return W
end

-- return NCA function:
return nca
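
-- Usage (a minimal sketch; the require path below is an assumption -- adjust it
-- to wherever this file lives on your Lua path, and pick num_dims/lambda to
-- suit your data):
--
--   local nca = require 'metriclearning.nca'
--   local X = torch.randn(100, 10)   -- 100 data points in 10 dimensions
--   local Y = torch.Tensor(100):apply(function() return torch.random(3) end)   -- labels 1..3
--   local W = nca(X, Y, {num_dims = 2, lambda = 0.01})   -- learn a 10x2 linear map
--   local Z = torch.mm(X, W)         -- data projected into the learned 2-D space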