Skip to content

Commit 8aaa17d

Browse files
committed
Added additional train() method to companion object for cluster count and tolerance parameters.
Modified cluster initialization strategy to use an initial covariance matrix derived from the sample points used to initialize the mean.
1 parent 676e523 commit 8aaa17d

File tree

1 file changed

+31
-5
lines changed

1 file changed

+31
-5
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximization.scala

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ object GMMExpectationMaximization {
3232
/**
3333
* Trains a GMM using the given parameters
3434
*
35-
* @param data training points stores as RDD[Vector]
35+
* @param data training points stored as RDD[Vector]
3636
* @param k the number of Gaussians in the mixture
3737
* @param maxIterations the maximum number of iterations to perform
3838
* @param delta change in log-likelihood at which convergence is considered achieved
@@ -47,7 +47,7 @@ object GMMExpectationMaximization {
4747
/**
4848
* Trains a GMM using the given parameters
4949
*
50-
* @param data training points stores as RDD[Vector]
50+
* @param data training points stored as RDD[Vector]
5151
* @param k the number of Gaussians in the mixture
5252
* @param maxIterations the maximum number of iterations to perform
5353
*/
@@ -58,7 +58,18 @@ object GMMExpectationMaximization {
5858
/**
5959
* Trains a GMM using the given parameters
6060
*
61-
* @param data training points stores as RDD[Vector]
61+
* @param data training points stored as RDD[Vector]
62+
* @param k the number of Gaussians in the mixture
63+
* @param delta change in log-likelihood at which convergence is considered achieved
64+
*/
65+
def train(data: RDD[Vector], k: Int, delta: Double): GaussianMixtureModel = {
  // Configure a freshly constructed estimator with the requested number of
  // Gaussians and the convergence tolerance; nothing else is set here, so
  // every other setting stays as the no-argument constructor left it.
  val estimator = new GMMExpectationMaximization()
    .setK(k)
    .setDelta(delta)
  // Fit the mixture to the training points and hand back the resulting model.
  estimator.run(data)
}
68+
69+
/**
70+
* Trains a GMM using the given parameters
71+
*
72+
* @param data training points stored as RDD[Vector]
6273
* @param k the number of Gaussians in the mixture
6374
*/
6475
def train(data: RDD[Vector], k: Int): GaussianMixtureModel = {
@@ -127,10 +138,12 @@ class GMMExpectationMaximization private (
127138

128139
// C will be array of (weight, mean, covariance) tuples
129140
// we start with uniform weights, a random mean from the data, and
130-
// identity matrices for covariance
141+
// diagonal covariance matrices using component variances
142+
// derived from the samples
131143
var C = (0 until k).map(i => (1.0/k,
132144
vec_mean(samples.slice(i * nSamples, (i + 1) * nSamples)),
133-
BreezeMatrix.eye[Double](d))).toArray
145+
init_cov(samples.slice(i * nSamples, (i + 1) * nSamples)))
146+
).toArray
134147

135148
val acc_w = new Array[Accumulator[Double]](k)
136149
val acc_mu = new Array[Accumulator[DenseDoubleVector]](k)
@@ -216,6 +229,19 @@ class GMMExpectationMaximization private (
216229
v / x.length.asInstanceOf[Double]
217230
}
218231

232+
/**
233+
* Construct matrix where diagonal entries are element-wise
234+
* variance of input vectors (computes biased variance)
235+
*/
236+
private def init_cov(x : Array[DenseDoubleVector]) : DenseDoubleMatrix = {
  // Fail fast with a descriptive message: an empty sample slice would otherwise
  // surface as an opaque index error when x(0) is read below.
  require(x.nonEmpty, "init_cov requires at least one sample vector")

  val mu = vec_mean(x)

  // Running element-wise sum of squared deviations from the sample mean.
  val ss = BreezeVector.zeros[Double](x(0).length)
  // Accumulate in place instead of first materializing a sequence of
  // squared-deviation vectors and then summing it.
  x.foreach(xi => ss += ((xi - mu) :^ 2.0))

  // Start from the identity so off-diagonal entries are zero, then overwrite
  // each diagonal entry with the biased variance (divide by N, not N - 1).
  val result = BreezeMatrix.eye[Double](ss.length)
  (0 until ss.length).foreach(i => result(i, i) = ss(i) / x.length)
  result
}
244+
219245
/** AccumulatorParam for Dense Breeze Vectors */
220246
private object DenseDoubleVectorAccumulatorParam extends AccumulatorParam[DenseDoubleVector] {
221247
def zero(initialVector : DenseDoubleVector) : DenseDoubleVector = {

0 commit comments

Comments
 (0)