-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathEigenGPARD_negLogLik.m
103 lines (90 loc) · 3.74 KB
/
EigenGPARD_negLogLik.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
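% Compute the negative log likelihood and its derivative with respect to each
% model parameter. We use an ARD kernel plus a linear kernel:
% k(x,y) = a0*exp(-(x-y)'*diag(eta)*(x-y))+a1*x'*y+a2
% NOTE: the parameter vector described below carries only sigma, eta, a0 and
% B; the linear-kernel terms a1 and a2 are not used by this implementation.
% parameters: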
% param - current parameters for the model
% element 1: log(sigma)
% element 2 - (1+D): log(eta)
% element (D+2): log(a0)
% element (D+3) - (D+2+D*M): B (the M by D matrix of basis points, stored
% column by column as a vector)
% X - input data points
% N by D matrix, where each row is a data point
% t - labels
% N by 1 vector
% M - number of basis points used
function [f df] = EigenGPARD_negLogLik(param, X, t, M)
% EIGENGPARD_NEGLOGLIK  Negative log likelihood of a sparse ARD GP and its gradient.
%
% Evaluates
%   f = (ln|CN| + t'*inv(CN)*t + N*ln(2*pi))/2
% with
%   CN       = sigma2*eye(N) + Kxb*inv(Kbb)*Kxb'
%   Kxb(i,j) = a0*exp(-sum(eta'.*(X(i,:)-B(j,:)).^2))              (N by M)
%   Kbb(i,j) = a0*exp(-sum(eta'.*(B(i,:)-B(j,:)).^2)) + epsilon*(i==j)  (M by M)
% CN itself is never formed: its inverse and log-determinant are obtained
% from the Cholesky factors of the M by M matrices Kbb and
% Q = Kbb + Kxb'*Kxb/sigma2.
%
% Inputs:
%   param - parameter vector laid out as
%           [log(sigma); log(eta) (D entries); log(a0); B(:) (M*D entries)].
%           It has to be a column vector: eta is used as D by 1 in X2*eta
%           and transposed to a row for the bsxfun scaling.
%   X     - N by D matrix, one data point per row
%   t     - N by 1 vector of targets
%   M     - number of basis points (rows of B)
%
% Outputs:
%   f  - scalar negative log likelihood
%   df - (D+2+D*M) by 1 gradient of f, in the same order as param, i.e.
%        with respect to log(sigma), log(eta), log(a0) and the entries of B
[N D] = size(X);
% load parameters
% param(1) holds log(sigma), so exp(2*param(1)) is the noise variance sigma^2
sigma2 = exp(2*param(1));
eta = exp(param(2:1+D));
a0 = exp(param(2+D));
% basis points, one per row (M by D)
B = reshape(param(3+D:D*M+D+2), M, D);
% jitter added to the diagonal of Kbb so that it stays numerically positive
% definite for the Cholesky factorization below
epsilon = 1e-10;
% Some commonly used terms
% element-wise squares, and inputs/basis points with each column d scaled by eta(d)
X2 = X.*X;
B2 = B.*B;
X_eta = bsxfun(@times,X,eta');
B_eta = bsxfun(@times,B,eta');
% Compute gram matrices
% The exponent is the expanded weighted squared distance:
% -(sum_d eta_d*x_d^2 - 2*sum_d eta_d*x_d*b_d + sum_d eta_d*b_d^2)
% expH is N by M (data vs. basis points), expF is M by M (basis vs. basis)
expH = exp(bsxfun(@minus,bsxfun(@minus,2*X_eta*B',X2*eta),(B2*eta)'));
Kxb = a0*expH;
expF = exp(bsxfun(@minus,bsxfun(@minus,2*B_eta*B',B2*eta),(B2*eta)'));
Kbb = a0*expF+epsilon*eye(M);
% Define Q = Kbb + 1/sigma2 * Kbx *Kxb
Q = Kbb+(Kxb'*Kxb)/sigma2;
% Cholesky factorization for stable computation
% both factors are lower triangular: cholKbb*cholKbb' = Kbb, cholQ*cholQ' = Q
cholKbb = chol(Kbb,'lower');
cholQ = chol(Q,'lower');
% Other commonly used terms
% linsolve options: lowerOpt solves L*x = b for lower-triangular L,
% upperOpt solves L'*x = b using the same lower-triangular factor
lowerOpt.LT = true; upperOpt.LT = true; upperOpt.TRANSA = true;
% cholQ \ (Kxb'/sigma2), M by N
invCholQ_Kbx_invSigma2 = linsolve(cholQ,Kxb'/sigma2,lowerOpt);
invCholQ_Kbx_invSigma2_t = invCholQ_Kbx_invSigma2*t;
% diagonal of inv(CN), using inv(CN) = eye(N)/sigma2 - Kxb*inv(Q)*Kxb'/sigma2^2
diagInvCN = 1/sigma2-sum(invCholQ_Kbx_invSigma2.^2, 1)';
% inv(CN)*t from the same identity
invCN_t = t/sigma2-invCholQ_Kbx_invSigma2'*invCholQ_Kbx_invSigma2_t;
% compute negative log likelihood function
% f = (ln|CN| + t'*inv(CN)*t + N*ln(2*pi))/2, where
% ln|CN| = N*ln(sigma2) + ln|Q| - ln|Kbb| and ln|A| = 2*sum(log(diag(chol(A))))
f = sum(log(diag(cholQ)))-sum(log(diag(cholKbb)))+(log(sigma2)*N+...
t'*t/sigma2-invCholQ_Kbx_invSigma2_t'*invCholQ_Kbx_invSigma2_t...
+N*log(2*pi))/2;
%f = sum(log(diag(cholQ)))-sum(log(diag(cholKbb)))+(log(sigma2)*N)/2;
%-----------------------
% compute gradient
%-----------------------
% prepare things that may be used later
% inv(Kbb)*Kxb'*inv(CN), computed as inv(Q)*Kxb'/sigma2 (M by N)
invKbb_Kbx_invCN = linsolve(cholQ,invCholQ_Kbx_invSigma2,upperOpt);
% inv(Kbb)*Kxb'*inv(CN)*Kxb*inv(Kbb) (M by M), via two triangular solves with cholKbb
invKbb_Kbx_invCN_Kxb_invKbb = linsolve(cholKbb, linsolve(cholKbb, Kxb'*invKbb_Kbx_invCN',lowerOpt),upperOpt)';
%invKbb_Kbx_invCN_Kxb_invKbb = inv(Kbb) - inv(Q)
% the same quantities with the data term t*t' inserted (outer products)
invKbb_Kbx_invCN_t = invKbb_Kbx_invCN*t;
invKbb_Kbx_invCN_t_t_invCN = invKbb_Kbx_invCN_t*invCN_t';
invKbb_Kbx_invCN_t_t_invCN_Kxb_invKbb = invKbb_Kbx_invCN_t*invKbb_Kbx_invCN_t';
% element-wise products with the kernel values: R* are M by N (against Kxb'),
% S* are M by M (against Kbb without the jitter). Index 1 feeds the
% log-determinant terms, index 2 the data-fit terms.
R1 = invKbb_Kbx_invCN.*(a0*expH)';
S1 = invKbb_Kbx_invCN_Kxb_invKbb.*(a0*expF);
R2 = invKbb_Kbx_invCN_t_t_invCN.*(a0*expH)';
S2 = invKbb_Kbx_invCN_t_t_invCN_Kxb_invKbb.*(a0*expF);
% compute dlogSigma
% dCN/dlog(sigma) = 2*sigma2*eye(N), so the derivative reduces to
% sigma2*(trace(inv(CN)) - ||inv(CN)*t||^2)
dlogSigma = sigma2*(sum(diagInvCN)-invCN_t'*invCN_t);
% compute dlogEta
% part1 = tr(inv(CN)*dCN)
part1 = 2*(2*sum(B'.*(X'*R1'), 2)-B2'*sum(R1,2)-X2'*sum(R1,1)')...
+(-2*sum(B.*(S1*B),1)'+2*B2'*sum(S1, 1)');
% part2 = tr(inv(CN)*t*t'*inv(CN)*dCN)
part2 = 2*(2*sum(B'.*(X'*R2'), 2)-B2'*sum(R2,2)-X2'*sum(R2,1)')...
+(-2*sum(B.*(S2*B),1)'+2*B2'*sum(S2, 1)');
% the factor eta converts d/d(eta) into d/d(log(eta)) (chain rule)
dlogEta = eta.*(part1-part2)/2;
% compute dlogA0
% part1 = tr(inv(CN)*dCN)
part1 = 2*sum(sum(invKbb_Kbx_invCN.*expH'))-sum(sum(invKbb_Kbx_invCN_Kxb_invKbb.*expF'));
%part1 = 2*sum(sum(invKbb_Kbx_invCN.*expH'))-trace(invKbb_Kbx_invCN_Kxb_invKbb_expF);
% part2 = tr(inv(CN)*t*t'*inv(CN)*dCN)
part2 = 2*sum(sum(invKbb_Kbx_invCN_t_t_invCN.*expH'))-sum(sum(invKbb_Kbx_invCN_t_t_invCN_Kxb_invKbb.*expF'));
% the factor a0 converts d/d(a0) into d/d(log(a0)) (chain rule)
dlogA0 = a0*(part1-part2)/2;
% compute dB
% part1 and part2 are M by D, one row per basis point; they play the same
% roles as above (log-determinant term and data-fit term respectively)
part1 = 2*(2*R1*X_eta-2*repmat(sum(R1,2),1,D).*B_eta)...
+(-4*S1*B_eta+4*repmat(sum(S1,2),1,D).*B_eta);
part2 = 2*(2*R2*X_eta-2*repmat(sum(R2,2),1,D).*B_eta)...
+(-4*S2*B_eta+4*repmat(sum(S2,2),1,D).*B_eta);
dB = (part1-part2)/2;
% combine all gradients in a vector
% dB is flattened column by column, matching the reshape that produced B from param
df = [dlogSigma; dlogEta; dlogA0; reshape(dB,D*M,1)];
end