-
Notifications
You must be signed in to change notification settings - Fork 4
/
fungrad.m
84 lines (65 loc) · 2.16 KB
/
fungrad.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
function [grad, store] = fungrad(problem, R, store)
% FUNGRAD  Gradient of funcost at R.
%
%   grad = fungrad(problem, R)
%   [grad, store] = fungrad(problem, R, store)
%
% Inputs:
%   problem - structure; the fields n, M, A, maskI, maskJ, kappa1 and
%             kappa2 are read here.
%   R       - point at which the gradient is evaluated; the returned grad
%             has the same size as R.
%   store   - (optional) cache structure. If it already has a field
%             'grad', that value is returned immediately. The fields hatZ,
%             ell1, ell2 and f are taken from it; when any of them is
%             missing, funcost is called to (re)build the cache.
%
% Outputs:
%   grad    - the gradient, skew-symmetrized slice by slice and set to
%             zero on the slices indexed by problem.A.
%   store   - the cache, extended with the fields 'grad' and 'g'.
%
% See also: funcost funhess
%
% Nicolas Boumal, UCLouvain, Nov. 20, 2012.

    % No cache supplied by the caller: start from an empty one.
    if nargin < 3
        store = struct();
    end

    % The gradient was already computed for this store: reuse it.
    if isfield(store, 'grad')
        grad = store.grad;
        return;
    end

    %% Pull the required data out of the problem structure.
    dim      = problem.n;
    numMeas  = problem.M;
    anchors  = problem.A;
    selI     = problem.maskI;
    selJ     = problem.maskJ;
    weight1  = problem.kappa1;
    weight2  = problem.kappa2;

    %% The cost function is responsible for producing hatZ, ell1, ell2
    % and f; call it if any of these is absent from the cache.
    needed = {'hatZ', 'ell1', 'ell2', 'f'};
    if ~all(isfield(store, needed))
        [~, store] = funcost(problem, R, store);
    end

    residuals = store.hatZ;
    term1     = store.ell1;
    term2     = store.ell2;
    fval      = store.f;

    %% Compute the gradient.
    g = (weight1 .* term1 + weight2 .* term2) ./ fval;
    scaledZ = multiscale(g, residuals);

    grad = zeros(size(R));

    % The accumulation is organized as a loop over the n-by-n matrix
    % entries, each handled with mask-vector products, instead of a loop
    % over the M elements: M may be of order N^2 and Matlab doesn't like
    % big loops. An equivalent but much slower formulation, kept here for
    % readability, is:
    %
    %   for k = 1 : M
    %       i = I(k);
    %       j = J(k);
    %       grad(:, :, i) = grad(:, :, i) + g_hatZ(:, :, k);
    %       grad(:, :, j) = grad(:, :, j) - g_hatZ(:, :, k);
    %   end
    %
    % (g_hatZ there is scaledZ here; I and J are not defined in this
    % function -- presumably the index vectors that maskI and maskJ
    % encode; TODO confirm.)
    for row = 1 : dim
        for col = 1 : dim
            % Entry (row, col) of every slice, as a column vector.
            fiber = squeeze(scaledZ(row, col, :));
            grad(row, col, :) = selI * fiber - selJ * fiber;
        end
    end

    grad = -grad / numMeas;

    %% Project the resulting vector from the ambient space to the tangent
    % space of P_A: keep the skew-symmetric part of each slice, then zero
    % the slices of the anchored rotations (indices in A), which are fixed
    % and therefore have no gradient component.
    grad = .5 * (grad - multitransp(grad));
    grad(:, :, anchors) = 0;

    %% Store some data for the Hessian function.
    store.grad = grad;
    store.g = g;

end