function mix = gmminit(mix, x, options)
%GMMINIT Initialises Gaussian mixture model from data
%
% Description
% MIX = GMMINIT(MIX, X, OPTIONS) uses a dataset X to initialise the
% parameters of a Gaussian mixture model defined by the data structure
% MIX. The k-means algorithm is used to determine the centres. The
% priors are computed from the proportion of examples belonging to each
% cluster. The covariance matrices are calculated as the sample
% covariance of the points associated with (i.e. closest to) the
% corresponding centres. For a mixture of PPCA model, the PPCA
% decomposition is calculated for the points closest to a given centre.
% This initialisation can be used as the starting point for training
% the model using the EM algorithm.
%
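% Example
% A minimal usage sketch (not part of this function), assuming
% 2-dimensional data in X, the companion Netlab functions GMM and GMMEM,
% and an options vector in the Netlab format (the Netlab demos build one
% with FOPTIONS):
%
%    mix = gmm(2, 4, 'diag');         % 2 inputs, 4 centres, diagonal covariances
%    options = foptions;              % default options vector
%    options(14) = 5;                 % 5 iterations of k-means in initialisation
%    mix = gmminit(mix, x, options);  % set centres, priors and covariances from data
%    [mix, options, errlog] = gmmem(mix, x, options);  % refine the model with EM
%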
% See also
% GMM
%
% Copyright (c) Ian T Nabney (1996-2001)

[ndata, xdim] = size(x);

% Check that inputs are consistent
errstring = consist(mix, 'gmm', x);
if ~isempty(errstring)
  error(errstring);
end

% Arbitrary width used if variance collapses to zero: make it 'large' so
% that centre is responsible for a reasonable number of points.
GMM_WIDTH = 1.0;

% Use kmeans algorithm to set centres
options(5) = 1;
[mix.centres, options, post] = kmeans(mix.centres, x, options);

% Set priors depending on number of points in each cluster
cluster_sizes = max(sum(post, 1), 1);  % Make sure that no prior is zero
mix.priors = cluster_sizes/sum(cluster_sizes); % Normalise priors

switch mix.covar_type
case 'spherical'
  if mix.ncentres > 1
    % Determine widths as distance to nearest centre
    % (or a constant if this is zero)
    cdist = dist2(mix.centres, mix.centres);
    cdist = cdist + diag(ones(mix.ncentres, 1)*realmax);
    mix.covars = min(cdist);
    mix.covars = mix.covars + GMM_WIDTH*(mix.covars < eps);
  else
    % Just use variance of all data points averaged over all
    % dimensions
    mix.covars = mean(diag(cov(x)));
  end
case 'diag'
  for j = 1:mix.ncentres
    % Pick out data points belonging to this centre
    c = x(find(post(:, j)),:);
    diffs = c - (ones(size(c, 1), 1) * mix.centres(j, :));
    mix.covars(j, :) = sum((diffs.*diffs), 1)/size(c, 1);
    % Replace small entries by GMM_WIDTH value
    mix.covars(j, :) = mix.covars(j, :) + GMM_WIDTH.*(mix.covars(j, :)<eps);
  end
case 'full'
  for j = 1:mix.ncentres
    % Pick out data points belonging to this centre
    c = x(find(post(:, j)),:);
    diffs = c - (ones(size(c, 1), 1) * mix.centres(j, :));
    mix.covars(:,:,j) = (diffs'*diffs)/(size(c, 1));
    % Add GMM_WIDTH*Identity to rank-deficient covariance matrices
    if rank(mix.covars(:,:,j)) < mix.nin
      mix.covars(:,:,j) = mix.covars(:,:,j) + GMM_WIDTH.*eye(mix.nin);
    end
  end
case 'ppca'
  for j = 1:mix.ncentres
    % Pick out data points belonging to this centre
    c = x(find(post(:,j)),:);
    diffs = c - (ones(size(c, 1), 1) * mix.centres(j, :));
    [tempcovars, tempU, templambda] = ...
      ppca((diffs'*diffs)/size(c, 1), mix.ppca_dim);
    if length(templambda) ~= mix.ppca_dim
      error('Unable to extract enough components');
    else
      mix.covars(j) = tempcovars;
      mix.U(:, :, j) = tempU;
      mix.lambda(j, :) = templambda;
    end
  end
otherwise
  error(['Unknown covariance type ', mix.covar_type]);
end