1717
1818package org .apache .spark .mllib .clustering
1919
20- import breeze .linalg .{DenseVector => BreezeVector , DenseMatrix => BreezeMatrix }
21- import breeze .linalg .Transpose
20+ import scala .collection .mutable .IndexedSeq
2221
22+ import breeze .linalg .{DenseVector => BreezeVector , DenseMatrix => BreezeMatrix , diag , Transpose }
2323import org .apache .spark .rdd .RDD
2424import org .apache .spark .mllib .linalg .{Matrices , Vector , Vectors }
2525import org .apache .spark .mllib .stat .impl .MultivariateGaussian
2626
27- import scala .collection .mutable .IndexedSeqView
28-
2927/**
3028 * This class performs expectation maximization for multivariate Gaussian
3129 * Mixture Models (GMMs). A GMM represents a composite distribution of
@@ -47,87 +45,34 @@ class GaussianMixtureModelEM private (
4745 private var k : Int ,
4846 private var convergenceTol : Double ,
4947 private var maxIterations : Int ) extends Serializable {
50-
51- // Type aliases for convenience
52- private type DenseDoubleVector = BreezeVector [Double ]
53- private type DenseDoubleMatrix = BreezeMatrix [Double ]
54- private type VectorArrayView = IndexedSeqView [DenseDoubleVector , Array [DenseDoubleVector ]]
55-
56- private type ExpectationSum = (
57- Array [Double ], // log-likelihood in index 0
58- Array [Double ], // array of weights
59- Array [DenseDoubleVector ], // array of means
60- Array [DenseDoubleMatrix ]) // array of cov matrices
61-
62- // create a zero'd ExpectationSum instance
63- private def zeroExpectationSum (k : Int , d : Int ): ExpectationSum = {
64- (Array (0.0 ),
65- new Array [Double ](k),
66- (0 until k).map(_ => BreezeVector .zeros[Double ](d)).toArray,
67- (0 until k).map(_ => BreezeMatrix .zeros[Double ](d,d)).toArray)
68- }
6948
70- // add two ExpectationSum objects (allowed to use modify m1)
71- // (U, U) => U for aggregation
72- private def addExpectationSums (m1 : ExpectationSum , m2 : ExpectationSum ): ExpectationSum = {
73- m1._1(0 ) += m2._1(0 )
74- var i = 0
75- while (i < m1._2.length) {
76- m1._2(i) += m2._2(i)
77- m1._3(i) += m2._3(i)
78- m1._4(i) += m2._4(i)
79- i = i + 1
80- }
81- m1
82- }
/**
 * Auxiliary constructor supplying the defaults: 2 Gaussians,
 * a 0.01 log-likelihood convergence threshold, and 100 iterations.
 */
def this() = this(k = 2, convergenceTol = 0.01, maxIterations = 100)
51+
8352
84- // compute cluster contributions for each input point
85- // (U, T) => U for aggregation
86- private def computeExpectation (
87- weights : Array [Double ],
88- dists : Array [MultivariateGaussian ])
89- (sums : ExpectationSum , x : DenseDoubleVector ): ExpectationSum = {
90- val k = sums._2.length
91- val p = weights.zip(dists).map { case (weight, dist) => eps + weight * dist.pdf(x) }
92- val pSum = p.sum
93- sums._1(0 ) += math.log(pSum)
94- val xxt = x * new Transpose (x)
95- var i = 0
96- while (i < k) {
97- p(i) /= pSum
98- sums._2(i) += p(i)
99- sums._3(i) += x * p(i)
100- sums._4(i) += xxt * p(i)
101- i = i + 1
102- }
103- sums
104- }
10553
// how many sampled points feed the initial estimate of each Gaussian
private val nSamples: Int = 5

// optional caller-supplied starting model; when it is None the
// default random starting point is used instead
private var initialModel: Option[GaussianMixtureModel] = None
11560
/**
 * Supply the starting model for EM, bypassing the random initialization.
 * setK() must be called before this method, because the model is only
 * accepted when (model.k == this.k).
 *
 * @param model the mixture model to start from
 * @return this instance, so that setter calls can be chained
 * @throws IllegalArgumentException if model.k differs from k
 */
def setInitialModel(model: GaussianMixtureModel): this.type = {
  // reject a model whose cluster count disagrees with ours before storing it
  if (model.k != k) {
    throw new IllegalArgumentException(" mismatched cluster count (model .k != k)")
  }
  initialModel = Some(model)
  this
}
12873
/**
 * The initial model previously passed to setInitialModel, or None
 * when the caller has not supplied one.
 */
def getInitialModel: Option[GaussianMixtureModel] = initialModel
13176
13277 /** Set the number of Gaussians in the mixture model. Default: 2 */
13378 def setK (k : Int ): this .type = {
@@ -161,9 +106,6 @@ class GaussianMixtureModelEM private (
161106 */
162107 def getConvergenceTol : Double = convergenceTol
163108
164- /** Machine precision value used to ensure matrix conditioning */
165- private val eps = math.pow(2.0 , - 52 )
166-
167109 /** Perform expectation maximization */
168110 def run (data : RDD [Vector ]): GaussianMixtureModel = {
169111 val sc = data.sparkContext
@@ -179,70 +121,122 @@ class GaussianMixtureModelEM private (
179121 // we start with uniform weights, a random mean from the data, and
180122 // diagonal covariance matrices using component variances
181123 // derived from the samples
182- val (weights, gaussians) = initialGmm match {
183- case Some (gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map{ case (mu, sigma) =>
124+ val (weights, gaussians) = initialModel match {
125+ case Some (gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case (mu, sigma) =>
184126 new MultivariateGaussian (mu.toBreeze.toDenseVector, sigma.toBreeze.toDenseMatrix)
185- }.toArray )
127+ })
186128
187129 case None => {
188130 val samples = breezeData.takeSample(true , k * nSamples, scala.util.Random .nextInt)
189- (Array .fill[ Double ] (k)(1.0 / k), ( 0 until k).map { i =>
131+ (Array .fill(k)(1.0 / k), Array .tabulate(k) { i =>
190132 val slice = samples.view(i * nSamples, (i + 1 ) * nSamples)
191133 new MultivariateGaussian (vectorMean(slice), initCovariance(slice))
192- }.toArray )
134+ })
193135 }
194136 }
195137
196138 var llh = Double .MinValue // current log-likelihood
197139 var llhp = 0.0 // previous log-likelihood
198140
199141 var iter = 0
200- do {
142+ while (iter < maxIterations && Math .abs(llh - llhp) > convergenceTol) {
201143 // create and broadcast curried cluster contribution function
202- val compute = sc.broadcast(computeExpectation (weights, gaussians)_)
144+ val compute = sc.broadcast(ExpectationSum .add (weights, gaussians)_)
203145
204146 // aggregate the cluster contribution for all sample points
205- val (logLikelihood, wSums, muSums, sigmaSums) =
206- breezeData.aggregate(zeroExpectationSum(k, d))(compute.value, addExpectationSums)
147+ val sums = breezeData.aggregate(ExpectationSum .zero(k, d))(compute.value, _ += _)
207148
208149 // Create new distributions based on the partial assignments
209150 // (often referred to as the "M" step in literature)
210- val sumWeights = wSums.sum
211- for (i <- 0 until k) {
212- val mu = muSums(i) / wSums(i)
213- val sigma = sigmaSums(i) / wSums(i) - mu * new Transpose (mu)
214- weights(i) = wSums(i) / sumWeights
151+ val sumWeights = sums.weights.sum
152+ var i = 0
153+ while (i < k) {
154+ val mu = sums.means(i) / sums.weights(i)
155+ val sigma = sums.sigmas(i) / sums.weights(i) - mu * new Transpose (mu) // TODO: Use BLAS.dsyr
156+ weights(i) = sums.weights(i) / sumWeights
215157 gaussians(i) = new MultivariateGaussian (mu, sigma)
158+ i = i + 1
216159 }
217160
218161 llhp = llh // current becomes previous
219- llh = logLikelihood( 0 ) // this is the freshly computed log-likelihood
162+ llh = sums. logLikelihood // this is the freshly computed log-likelihood
220163 iter += 1
221- } while (iter < maxIterations && Math .abs(llh - llhp) > convergenceTol)
164+ }
222165
223166 // Need to convert the breeze matrices to MLlib matrices
224- val means = ( 0 until k).map( i => Vectors .fromBreeze(gaussians(i).mu)).toArray
225- val sigmas = ( 0 until k).map( i => Matrices .fromBreeze(gaussians(i).sigma)).toArray
167+ val means = Array .tabulate(k) { i => Vectors .fromBreeze(gaussians(i).mu) }
168+ val sigmas = Array .tabulate(k) { i => Matrices .fromBreeze(gaussians(i).sigma) }
226169 new GaussianMixtureModel (weights, means, sigmas)
227170 }
228171
/**
 * Element-wise mean of a sequence of dense breeze vectors.
 * The dimension is read from the first element, so x must be non-empty.
 */
private def vectorMean(x: IndexedSeq[BreezeVector[Double]]): BreezeVector[Double] = {
  val dim = x(0).length
  // accumulate in place into a single zero vector, visiting elements in order
  val total = x.foldLeft(BreezeVector.zeros[Double](dim)) { (acc, xi) =>
    acc += xi
    acc
  }
  total / x.length.toDouble
}
235178
/**
 * Build a diagonal matrix whose diagonal holds the per-component
 * variance of the input vectors. The squared deviations are divided
 * by the number of vectors, i.e. this is the biased variance.
 */
private def initCovariance(x: IndexedSeq[BreezeVector[Double]]): BreezeMatrix[Double] = {
  val center = vectorMean(x)
  val sumSquares = BreezeVector.zeros[Double](x(0).length)
  // add up the squared deviation from the mean, one input vector at a time
  x.foreach { xi =>
    sumSquares += ((xi - center) :^ 2.0)
  }
  diag(sumSquares / x.length.toDouble)
}
248189}
190+
// companion object of ExpectationSum: provides the zero element and the
// per-point accumulation function used when aggregating over the data
private object ExpectationSum {
  // 2^-52, added to every weighted density in add() below
  // NOTE(review): presumably this keeps the density sum away from zero before
  // the log and the division -- confirm that pdf never returns a negative value
  private val eps = math.pow(2.0, -52)

  /** An all-zero accumulator for k clusters of d-dimensional points. */
  def zero(k: Int, d: Int): ExpectationSum = {
    val weights = Array.fill(k)(0.0)
    val means = Array.fill(k)(BreezeVector.zeros[Double](d))
    val sigmas = Array.fill(k)(BreezeMatrix.zeros[Double](d, d))
    new ExpectationSum(0.0, weights, means, sigmas)
  }

  /**
   * Fold the contribution of one point x into sums and return sums.
   * The first parameter list (mixture weights and distributions) can be
   * applied on its own, leaving a (U, T) => U function for aggregation.
   * sums is modified in place.
   */
  def add(
      weights: Array[Double],
      dists: Array[MultivariateGaussian])
      (sums: ExpectationSum, x: BreezeVector[Double]): ExpectationSum = {
    // weighted density of x under each cluster, offset by eps
    val p = weights.zip(dists).map { pair => eps + pair._1 * pair._2.pdf(x) }
    val pSum = p.sum
    sums.logLikelihood += math.log(pSum)
    // outer product of x with itself, shared by every cluster's update
    val xxt = x * new Transpose(x)
    for (i <- 0 until sums.k) {
      // normalise so that the contributions of x across clusters sum to one
      p(i) /= pSum
      sums.weights(i) += p(i)
      sums.means(i) += x * p(i)
      sums.sigmas(i) += xxt * p(i) // TODO: use BLAS.dsyr
    }
    sums
  }
}
221+
// Mutable accumulator for partial expectation results: a running
// log-likelihood plus, per cluster, a weight, a mean sum and a sigma sum
private class ExpectationSum(
    var logLikelihood: Double,
    val weights: Array[Double],
    val means: Array[BreezeVector[Double]],
    val sigmas: Array[BreezeMatrix[Double]]) extends Serializable {

  // number of clusters, taken from the length of the weights array
  val k = weights.length

  /**
   * Add every component of x into this accumulator, in place.
   * Returns this instance so it can serve as the (U, U) => U combiner.
   */
  def +=(x: ExpectationSum): ExpectationSum = {
    (0 until k).foreach { idx =>
      weights(idx) += x.weights(idx)
      means(idx) += x.means(idx)
      sigmas(idx) += x.sigmas(idx)
    }
    logLikelihood += x.logLikelihood
    this
  }
}
0 commit comments