-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlstm.m
125 lines (97 loc) · 2.64 KB
/
lstm.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
function [dWf,dRf,dbf,dWi,dRi,dbi,dWg,dRg,dbg,dWo,dRo,dbo,dV,db, hend, cend, loss, y_hat] = lstm(...
    Wf,Rf,bf,Wi,Ri,bi,Wg,Rg,bg,Wo,Ro,bo,V,b,inputs, targets, h0, c0)
%LSTM Forward pass and backpropagation through time for a single-layer LSTM.
%   [dWf,dRf,dbf,dWi,dRi,dbi,dWg,dRg,dbg,dWo,dRo,dbo,dV,db,hend,cend,loss,y_hat] = ...
%       LSTM(Wf,Rf,bf,Wi,Ri,bi,Wg,Rg,bg,Wo,Ro,bo,V,b,inputs,targets,h0,c0)
%
%   Runs the sequence INPUTS (one time step per column) through an LSTM
%   starting from hidden state H0 and cell state C0:
%
%       f_t = sigmoid(Wf*x_t + Rf*h_{t-1} + bf)    forget gate
%       i_t = sigmoid(Wi*x_t + Ri*h_{t-1} + bi)    input gate
%       g_t = tanh   (Wg*x_t + Rg*h_{t-1} + bg)    candidate
%       o_t = sigmoid(Wo*x_t + Ro*h_{t-1} + bo)    output gate
%       c_t = f_t.*c_{t-1} + i_t.*g_t
%       h_t = o_t.*tanh(c_t)
%       y_t = V*h_t + b
%
%   LOSS is the squared error 0.5*sum((targets(:,t) - y_t).^2) summed over
%   all time steps. Each d* output is the gradient of LOSS with respect to
%   the parameter of the same name (e.g. dWf = dLOSS/dWf), computed by
%   backpropagation through time.
%
%   HEND and CEND are the hidden and cell states after the last step (H0
%   and C0 when the sequence is empty). Y_HAT holds the predictions y_t,
%   one column per time step.
%
%   With H = size(Rf,1) hidden units, D = size(V,1) outputs and
%   X = size(inputs,1) input features:
%     - Wf/Wi/Wg/Wo are H-by-X.
%     - Rf/Ri/Rg/Ro are H-by-H.
%     - bf/bi/bg/bo, h0 and c0 are H-by-1.
%     - V is D-by-H and b is D-by-1.
%     - TARGETS is D-by-T.
%
%   Raises an error if INPUTS and TARGETS have different numbers of columns.

    T = size(inputs, 2);  % sequence length
    if (T ~= size(targets, 2))
        error('Number of input samples and desired output samples does not match.');
    end
    H = size(Rf, 1);  % number of hidden units
    D = size(V, 1);   % number of output units

    % Per-step activations cached by the forward pass for use in BPTT.
    h = zeros(H, T);
    c = zeros(H, T);
    f = zeros(H, T);
    i = zeros(H, T);
    g = zeros(H, T);
    o = zeros(H, T);
    y = zeros(D, T);

    % ---------------- forward pass ----------------
    ht = h0;
    ct = c0;
    loss = 0;
    for j = 1:T
        xt = inputs(:, j);
        ft = logistic(Wf*xt + Rf*ht + bf);
        it = logistic(Wi*xt + Ri*ht + bi);
        gt = tanh(Wg*xt + Rg*ht + bg);
        ot = logistic(Wo*xt + Ro*ht + bo);
        ct = ft.*ct + it.*gt;
        ht = ot.*tanh(ct);
        yt = V*ht + b;
        loss = loss + 0.5*sum((targets(:, j) - yt).^2);
        c(:, j) = ct;
        h(:, j) = ht;
        f(:, j) = ft;
        i(:, j) = it;
        g(:, j) = gt;
        o(:, j) = ot;
        y(:, j) = yt;
    end

    % ---------------- backward pass (BPTT) ----------------
    dWf = zeros(size(Wf)); dRf = zeros(size(Rf)); dbf = zeros(size(bf));
    dWi = zeros(size(Wi)); dRi = zeros(size(Ri)); dbi = zeros(size(bi));
    dWg = zeros(size(Wg)); dRg = zeros(size(Rg)); dbg = zeros(size(bg));
    dWo = zeros(size(Wo)); dRo = zeros(size(Ro)); dbo = zeros(size(bo));
    dV = zeros(size(V)); db = zeros(size(b));
    dhnext = zeros(size(h0));  % dLOSS/dh_t flowing back from step t+1
    dcnext = zeros(size(c0));  % dLOSS/dc_t flowing back from step t+1
    for j = T:-1:1
        xt = inputs(:, j);
        ht = h(:, j);  % must be this step's state, not the last forward value
        ct = c(:, j);
        ft = f(:, j);
        it = i(:, j);
        gt = g(:, j);
        ot = o(:, j);
        if j == 1
            hprev = h0;
            cprev = c0;
        else
            hprev = h(:, j-1);
            cprev = c(:, j-1);
        end

        % Output layer: y_t = V*h_t + b.
        dyt = y(:, j) - targets(:, j);
        dV = dV + dyt*ht';
        db = db + dyt;

        % h_t = o_t.*tanh(c_t), so dh_t/dc_t = o_t.*(1 - tanh(c_t).^2).
        tct = tanh(ct);
        dht = V'*dyt + dhnext;
        dct = dht.*ot.*(1 - tct.^2) + dcnext;

        % Gradients w.r.t. the gate activations.
        dft = dct.*cprev;
        dit = dct.*gt;
        dgt = dct.*it;
        dot_ = dht.*tct;  % trailing underscore avoids shadowing builtin dot

        % Gradients w.r.t. the gate pre-activations.
        dff = dft.*ft.*(1 - ft);
        dii = dit.*it.*(1 - it);
        dgg = dgt.*(1 - gt.^2);
        doo = dot_.*ot.*(1 - ot);

        dWf = dWf + dff*xt';
        dWi = dWi + dii*xt';
        dWg = dWg + dgg*xt';
        dWo = dWo + doo*xt';
        dRf = dRf + dff*hprev';
        dRi = dRi + dii*hprev';
        dRg = dRg + dgg*hprev';
        dRo = dRo + doo*hprev';
        dbf = dbf + dff;
        dbi = dbi + dii;
        dbg = dbg + dgg;
        dbo = dbo + doo;

        % Propagate to step j-1. The recurrent matrices multiply the full
        % pre-activation gradients: '*' and '.*' share precedence in MATLAB,
        % so Rf'*dft.*ft would wrongly apply .*ft after the matrix product.
        dcnext = dct.*ft;
        dhnext = Rf'*dff + Ri'*dii + Rg'*dgg + Ro'*doo;
    end

    if T == 0
        % No steps were taken: the final state is the initial state.
        hend = h0;
        cend = c0;
    else
        hend = h(:, end);
        cend = c(:, end);
    end
    y_hat = y;
end

function s = logistic(z)
%LOGISTIC Elementwise logistic sigmoid 1./(1+exp(-z)).
%   Same formula as logsig, without requiring the Deep Learning Toolbox.
    s = 1 ./ (1 + exp(-z));
end