-
Notifications
You must be signed in to change notification settings - Fork 17
/
demmdn1.m
211 lines (191 loc) · 6.14 KB
/
demmdn1.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
%DEMMDN1 Demonstrate fitting a multi-valued function using a Mixture Density Network.
%
% Description
% The problem consists of one input variable X and one target variable
% T. The data are generated by sampling T at equal intervals on [0, 1]
% and then generating the input data X by computing T + 0.3*SIN(2*PI*T)
% and adding zero-mean uniform noise. A Mixture Density Network with 3
% centres in the mixture model is trained by minimizing a negative log
% likelihood error function using the scaled conjugate gradient optimizer.
%
% The conditional means, mixing coefficients and variances are plotted
% as a function of X, and a contour plot of the full conditional
% density is also generated.
%
% See also
% MDN, MDNERR, MDNGRAD, SCG
%
% Copyright (c) Ian T Nabney (1996-2001)
% Generate the matrix of inputs x and targets t.
% Fixed seeds for the legacy uniform and normal generators make the demo
% reproducible from run to run.
seed = 42;
seedn = 42;
randn('state', seedn);
rand('state', seed);
ndata = 300;    % number of data points
noise = 0.2;    % width of the (uniform) noise interval
% Targets are evenly spaced on [0, 1]; each input is a nonlinear function
% of its target plus noise drawn uniformly from [-noise/2, noise/2].
t = (0:1/(ndata - 1):1)';
x = t + 0.3*sin(2*pi*t) + noise*rand(ndata, 1) - noise/2;
% Shared plot window: [xmin xmax ymin ymax].
axis_limits = [-0.2 1.2 -0.2 1.2];
clc
fprintf('%s\n', ...
  'This demonstration illustrates the use of a Mixture Density Network', ...
  'to model multi-valued functions. The data is generated from the', ...
  'mapping x = t + 0.3 sin(2 pi t) + e, where e is a noise term.', ...
  'We begin by plotting the data.', ...
  ' ', ...
  'Press any key to continue')
pause
% Scatter the training samples (input x horizontally, target t vertically)
% and keep hold on so later curves are drawn on the same axes.
fh1 = figure;
p1 = plot(x, t, 'ob');
axis(axis_limits);
hold on
fprintf('%s\n', ...
  'Note that for x in the range 0.35 to 0.65, there are three possible', ...
  'branches of the function.', ...
  ' ', ...
  'Press any key to continue')
pause
% Architecture of the Mixture Density Network.
nin = 1;          % number of inputs
nhidden = 5;      % number of hidden units
ncentres = 3;     % number of mixture components
dim_target = 1;   % dimension of the target space
mdntype = '0';    % currently unused: reserved for future use
% Inverse variance used for weight initialisation; a large value means a
% small initial variance, which gives a good starting point.
alpha = 100;
net = mdn(nin, nhidden, ncentres, dim_target, mdntype);
% mdninit options: element 1 = -1 suppresses all messages, element 14 = 10
% runs 10 iterations of K-means in gmminit.
init_options = zeros(1, 18);
init_options([1 14]) = [-1 10];
net = mdninit(net, alpha, t, init_options);
% Optimiser options: element 1 = 1 displays error values, element 14 sets
% the number of training cycles to 200.
options = foptions;
options([1 14]) = [1 200];
clc
fprintf('%s\n', ...
  'We initialise the neural network model, which is an MLP with a', ...
  'Gaussian mixture model with three components and spherical variance', ...
  'as the error function. This enables us to model the complete', ...
  'conditional density function.', ...
  ' ', ...
  'Next we train the model for 200 epochs using a scaled conjugate gradient', ...
  'optimizer. The error function is the negative log likelihood of the', ...
  'training data.', ...
  ' ', ...
  'Press any key to continue.')
pause
% Fit the MDN with scaled conjugate gradients ('scg'); netopt hands back
% the trained network together with the updated options vector.
[net, options] = netopt(net, options, x, t, 'scg');
fprintf('%s\n', ' ', 'Press any key to continue.')
pause
clc
fprintf('%s\n', ...
  'We can also train a conventional MLP with sum of squares error function.', ...
  'This will approximate the conditional mean, which is not always a', ...
  'good representation of the data. Note that the error function is the', ...
  'sum of squares error on the training data, which accounts for the', ...
  'different values from training the MDN.', ...
  ' ', ...
  'We train the network with the quasi-Newton optimizer for 80 epochs.', ...
  ' ', ...
  'Press any key to continue.')
pause
% Baseline model: an MLP with linear outputs trained by quasi-Newton on the
% same data. The options vector from the MDN run is reused; only the number
% of training cycles is changed.
mlp_nhidden = 8;
net2 = mlp(nin, mlp_nhidden, dim_target, 'linear');
options(14) = 80;
[net2, options] = netopt(net2, options, x, t, 'quasinew');
fprintf('%s\n', ' ', 'Press any key to continue.')
pause
clc
disp('Now we plot the underlying function, the MDN prediction,')
disp('represented by the mode of the conditional distribution, and the')
disp('prediction of the conventional MLP.')
disp(' ')
disp('Press any key to continue.')
pause
% Plot the original function and the trained network functions on top of
% the data in the first figure (hold is still on there).
plotvals = [0:0.01:1]';
% mixes(n) is the mixture model predicted by the MDN for input plotvals(n).
mixes = mdn2gmm(mdnfwd(net, plotvals));
axis(axis_limits);
yplot = t+0.3*sin(2*pi*t);
p2 = plot(yplot, t, '--y');
% Represent the MDN prediction by the centre of the component with the
% largest mixing coefficient (an approximation to the conditional mode).
% Also record every centre for the parameter plots that follow; the
% transpose below assumes dim_target == 1.
nplot = length(plotvals);
y = zeros(1, nplot);
c = zeros(nplot, ncentres);
for n = 1:nplot
  [maxprior, best] = max(mixes(n).priors);
  y(n) = mixes(n).centres(best,:);
  c(n,:) = mixes(n).centres';
end
p3 = plot(plotvals, y, '*r');
p4 = plot(plotvals, mlpfwd(net2, plotvals), 'g');
set(p4, 'LineWidth', 2);
% 'southeast' is the documented equivalent of the legacy numeric
% location 4 (lower right-hand corner).
legend([p1 p2 p3 p4], 'data', 'function', 'MDN mode', 'MLP mean', ...
  'Location', 'southeast');
hold off
clc
disp('We can also plot how the mixture model parameters depend on x.')
disp('First we plot the mixture centres, then the priors and finally')
disp('the variances.')
disp(' ')
disp('Press any key to continue.')
pause
% One legend label per mixture component, so the legends stay correct if
% ncentres is changed.
centre_labels = cell(1, ncentres);
for k = 1:ncentres
  centre_labels{k} = sprintf('centre %d', k);
end
fh2 = figure;
% Centres: c holds one row per plotted input, one column per component.
subplot(3, 1, 1)
plot(plotvals, c)
title('Mixture centres')
legend(centre_labels{:})
% Stack the per-input priors into a matrix with one row per plotted input
% and one column per component.
priors = reshape([mixes.priors], mixes(1).ncentres, size(mixes, 2))';
subplot(3, 1, 2)
plot(plotvals, priors)
title('Mixture priors')
legend(centre_labels{:})
% Same reshaping for the component variances.
variances = reshape([mixes.covars], mixes(1).ncentres, size(mixes, 2))';
subplot(3, 1, 3)
plot(plotvals, variances)
title('Mixture variances')
legend(centre_labels{:})
disp('The last figure is a contour plot of the conditional probability')
disp('density generated by the Mixture Density Network. Note how it')
disp('is well matched to the regions of high data density.')
disp(' ')
disp('Press any key to continue.')
pause
% Contour plot for MDN.
% The input axis of the grid is taken from plotvals so that column k of Z
% is evaluated with mixes(k), the mixture predicted at plotvals(k).
xgrid = plotvals';
tgrid = 0:0.01:1.0;
% Rows of Z follow the target axis (tgrid) and columns the input axis
% (xgrid), which is the orientation contour(X, Y, Z) expects.
Z = zeros(length(tgrid), length(xgrid));
for k = 1:length(xgrid)
  Z(:, k) = gmmprob(mixes(k), tgrid');
end
fh5 = figure;
% Set up levels by hand to make a good figure
v = [2 2.5 3 3.5 5:3:18];
contour(xgrid, tgrid, Z, v)
title('Contour plot of conditional density')
fprintf('%s\n', ' ', 'Press any key to exit.')
pause
% Close the three figures this demo opened (the parameter plots share one
% figure, fh2, as subplots).
close(fh1);
close(fh2);
close(fh5);
% The workspace is not cleared (clear all is left disabled), so net, net2
% and mixes remain available after the demo.