Skip to content

Commit 9b2fc2a

Browse files
committed
Style improvements
Changed ExpectationSum to a private class
1 parent b97fe00 commit 9b2fc2a

File tree

5 files changed

+108
-115
lines changed

5 files changed

+108
-115
lines changed

examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ object DenseGmmEM {
3838
}
3939

4040
private def run(inputFile: String, k: Int, convergenceTol: Double) {
41-
val conf = new SparkConf().setAppName("Spark EM Sample")
41+
val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example")
4242
val ctx = new SparkContext(conf)
4343

44-
val data = ctx.textFile(inputFile).map{ line =>
44+
val data = ctx.textFile(inputFile).map { line =>
4545
Vectors.dense(line.trim.split(' ').map(_.toDouble))
4646
}.cache()
4747

@@ -56,8 +56,8 @@ object DenseGmmEM {
5656
}
5757

5858
println("Cluster labels (first <= 100):")
59-
val (responsibilityMatrix, clusterLabels) = clusters.predict(data)
60-
clusterLabels.take(100).foreach{ x =>
59+
val clusterLabels = clusters.predictLabels(data)
60+
clusterLabels.take(100).foreach { x =>
6161
print(" " + x)
6262
}
6363
println()

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ package org.apache.spark.mllib.clustering
2020
import breeze.linalg.{DenseVector => BreezeVector}
2121

2222
import org.apache.spark.rdd.RDD
23-
import org.apache.spark.mllib.linalg.Matrix
24-
import org.apache.spark.mllib.linalg.Vector
23+
import org.apache.spark.mllib.linalg.{Matrix, Vector}
2524
import org.apache.spark.mllib.stat.impl.MultivariateGaussian
2625

2726
/**
@@ -44,10 +43,9 @@ class GaussianMixtureModel(
4443
def k: Int = weight.length
4544

4645
/** Maps given points to their cluster indices. */
47-
def predict(points: RDD[Vector]): (RDD[Array[Double]],RDD[Int]) = {
48-
val responsibilityMatrix = predictMembership(points,mu,sigma,weight,k)
49-
val clusterLabels = responsibilityMatrix.map(r => r.indexOf(r.max))
50-
(responsibilityMatrix, clusterLabels)
46+
def predictLabels(points: RDD[Vector]): RDD[Int] = {
47+
val responsibilityMatrix = predictMembership(points, mu, sigma, weight, k)
48+
responsibilityMatrix.map(r => r.indexOf(r.max))
5149
}
5250

5351
/**
@@ -58,15 +56,16 @@ class GaussianMixtureModel(
5856
points: RDD[Vector],
5957
mu: Array[Vector],
6058
sigma: Array[Matrix],
61-
weight: Array[Double], k: Int): RDD[Array[Double]] = {
59+
weight: Array[Double],
60+
k: Int): RDD[Array[Double]] = {
6261
val sc = points.sparkContext
63-
val dists = sc.broadcast{
64-
(0 until k).map{ i =>
62+
val dists = sc.broadcast {
63+
(0 until k).map { i =>
6564
new MultivariateGaussian(mu(i).toBreeze.toDenseVector, sigma(i).toBreeze.toDenseMatrix)
6665
}.toArray
6766
}
6867
val weights = sc.broadcast(weight)
69-
points.map{ x =>
68+
points.map { x =>
7069
computeSoftAssignments(x.toBreeze.toDenseVector, dists.value, weights.value, k)
7170
}
7271
}
@@ -86,7 +85,7 @@ class GaussianMixtureModel(
8685
k: Int): Array[Double] = {
8786
val p = weights.zip(dists).map { case (weight, dist) => eps + weight * dist.pdf(pt) }
8887
val pSum = p.sum
89-
for (i <- 0 until k){
88+
for (i <- 0 until k) {
9089
p(i) /= pSum
9190
}
9291
p

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModelEM.scala

Lines changed: 88 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,13 @@
1717

1818
package org.apache.spark.mllib.clustering
1919

20-
import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix}
21-
import breeze.linalg.Transpose
20+
import scala.collection.mutable.IndexedSeq
2221

22+
import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix, diag, Transpose}
2323
import org.apache.spark.rdd.RDD
2424
import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors}
2525
import org.apache.spark.mllib.stat.impl.MultivariateGaussian
2626

27-
import scala.collection.mutable.IndexedSeqView
28-
2927
/**
3028
* This class performs expectation maximization for multivariate Gaussian
3129
* Mixture Models (GMMs). A GMM represents a composite distribution of
@@ -47,87 +45,34 @@ class GaussianMixtureModelEM private (
4745
private var k: Int,
4846
private var convergenceTol: Double,
4947
private var maxIterations: Int) extends Serializable {
50-
51-
// Type aliases for convenience
52-
private type DenseDoubleVector = BreezeVector[Double]
53-
private type DenseDoubleMatrix = BreezeMatrix[Double]
54-
private type VectorArrayView = IndexedSeqView[DenseDoubleVector, Array[DenseDoubleVector]]
55-
56-
private type ExpectationSum = (
57-
Array[Double], // log-likelihood in index 0
58-
Array[Double], // array of weights
59-
Array[DenseDoubleVector], // array of means
60-
Array[DenseDoubleMatrix]) // array of cov matrices
61-
62-
// create a zero'd ExpectationSum instance
63-
private def zeroExpectationSum(k: Int, d: Int): ExpectationSum = {
64-
(Array(0.0),
65-
new Array[Double](k),
66-
(0 until k).map(_ => BreezeVector.zeros[Double](d)).toArray,
67-
(0 until k).map(_ => BreezeMatrix.zeros[Double](d,d)).toArray)
68-
}
6948

70-
// add two ExpectationSum objects (allowed to use modify m1)
71-
// (U, U) => U for aggregation
72-
private def addExpectationSums(m1: ExpectationSum, m2: ExpectationSum): ExpectationSum = {
73-
m1._1(0) += m2._1(0)
74-
var i = 0
75-
while (i < m1._2.length) {
76-
m1._2(i) += m2._2(i)
77-
m1._3(i) += m2._3(i)
78-
m1._4(i) += m2._4(i)
79-
i = i + 1
80-
}
81-
m1
82-
}
49+
/** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */
50+
def this() = this(2, 0.01, 100)
51+
8352

84-
// compute cluster contributions for each input point
85-
// (U, T) => U for aggregation
86-
private def computeExpectation(
87-
weights: Array[Double],
88-
dists: Array[MultivariateGaussian])
89-
(sums: ExpectationSum, x: DenseDoubleVector): ExpectationSum = {
90-
val k = sums._2.length
91-
val p = weights.zip(dists).map { case (weight, dist) => eps + weight * dist.pdf(x) }
92-
val pSum = p.sum
93-
sums._1(0) += math.log(pSum)
94-
val xxt = x * new Transpose(x)
95-
var i = 0
96-
while (i < k) {
97-
p(i) /= pSum
98-
sums._2(i) += p(i)
99-
sums._3(i) += x * p(i)
100-
sums._4(i) += xxt * p(i)
101-
i = i + 1
102-
}
103-
sums
104-
}
10553

10654
// number of samples per cluster to use when initializing Gaussians
10755
private val nSamples = 5
10856

10957
// an initializing GMM can be provided rather than using the
11058
// default random starting point
111-
private var initialGmm: Option[GaussianMixtureModel] = None
112-
113-
/** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */
114-
def this() = this(2, 0.01, 100)
59+
private var initialModel: Option[GaussianMixtureModel] = None
11560

11661
/** Set the initial GMM starting point, bypassing the random initialization.
11762
* You must call setK() prior to calling this method, and the condition
118-
* (gmm.k == this.k) must be met; failure will result in an IllegalArgumentException
63+
* (model.k == this.k) must be met; failure will result in an IllegalArgumentException
11964
*/
120-
def setInitialGmm(gmm: GaussianMixtureModel): this.type = {
121-
if (gmm.k == k) {
122-
initialGmm = Some(gmm)
65+
def setInitialModel(model: GaussianMixtureModel): this.type = {
66+
if (model.k == k) {
67+
initialModel = Some(model)
12368
} else {
124-
throw new IllegalArgumentException("initialing GMM has mismatched cluster count (gmm.k != k)")
69+
throw new IllegalArgumentException("mismatched cluster count (model.k != k)")
12570
}
12671
this
12772
}
12873

12974
/** Return the user-supplied initial model, if one was set */
130-
def getInitialGmm: Option[GaussianMixtureModel] = initialGmm
75+
def getInitialModel: Option[GaussianMixtureModel] = initialModel
13176

13277
/** Set the number of Gaussians in the mixture model. Default: 2 */
13378
def setK(k: Int): this.type = {
@@ -161,9 +106,6 @@ class GaussianMixtureModelEM private (
161106
*/
162107
def getConvergenceTol: Double = convergenceTol
163108

164-
/** Machine precision value used to ensure matrix conditioning */
165-
private val eps = math.pow(2.0, -52)
166-
167109
/** Perform expectation maximization */
168110
def run(data: RDD[Vector]): GaussianMixtureModel = {
169111
val sc = data.sparkContext
@@ -179,70 +121,122 @@ class GaussianMixtureModelEM private (
179121
// we start with uniform weights, a random mean from the data, and
180122
// diagonal covariance matrices using component variances
181123
// derived from the samples
182-
val (weights, gaussians) = initialGmm match {
183-
case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map{ case(mu, sigma) =>
124+
val (weights, gaussians) = initialModel match {
125+
case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case(mu, sigma) =>
184126
new MultivariateGaussian(mu.toBreeze.toDenseVector, sigma.toBreeze.toDenseMatrix)
185-
}.toArray)
127+
})
186128

187129
case None => {
188130
val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt)
189-
(Array.fill[Double](k)(1.0 / k), (0 until k).map{ i =>
131+
(Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
190132
val slice = samples.view(i * nSamples, (i + 1) * nSamples)
191133
new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
192-
}.toArray)
134+
})
193135
}
194136
}
195137

196138
var llh = Double.MinValue // current log-likelihood
197139
var llhp = 0.0 // previous log-likelihood
198140

199141
var iter = 0
200-
do {
142+
while(iter < maxIterations && Math.abs(llh-llhp) > convergenceTol) {
201143
// create and broadcast curried cluster contribution function
202-
val compute = sc.broadcast(computeExpectation(weights, gaussians)_)
144+
val compute = sc.broadcast(ExpectationSum.add(weights, gaussians)_)
203145

204146
// aggregate the cluster contribution for all sample points
205-
val (logLikelihood, wSums, muSums, sigmaSums) =
206-
breezeData.aggregate(zeroExpectationSum(k, d))(compute.value, addExpectationSums)
147+
val sums = breezeData.aggregate(ExpectationSum.zero(k, d))(compute.value, _ += _)
207148

208149
// Create new distributions based on the partial assignments
209150
// (often referred to as the "M" step in literature)
210-
val sumWeights = wSums.sum
211-
for (i <- 0 until k) {
212-
val mu = muSums(i) / wSums(i)
213-
val sigma = sigmaSums(i) / wSums(i) - mu * new Transpose(mu)
214-
weights(i) = wSums(i) / sumWeights
151+
val sumWeights = sums.weights.sum
152+
var i = 0
153+
while (i < k) {
154+
val mu = sums.means(i) / sums.weights(i)
155+
val sigma = sums.sigmas(i) / sums.weights(i) - mu * new Transpose(mu) // TODO: Use BLAS.dsyr
156+
weights(i) = sums.weights(i) / sumWeights
215157
gaussians(i) = new MultivariateGaussian(mu, sigma)
158+
i = i + 1
216159
}
217160

218161
llhp = llh // current becomes previous
219-
llh = logLikelihood(0) // this is the freshly computed log-likelihood
162+
llh = sums.logLikelihood // this is the freshly computed log-likelihood
220163
iter += 1
221-
} while(iter < maxIterations && Math.abs(llh-llhp) > convergenceTol)
164+
}
222165

223166
// Need to convert the breeze matrices to MLlib matrices
224-
val means = (0 until k).map(i => Vectors.fromBreeze(gaussians(i).mu)).toArray
225-
val sigmas = (0 until k).map(i => Matrices.fromBreeze(gaussians(i).sigma)).toArray
167+
val means = Array.tabulate(k) { i => Vectors.fromBreeze(gaussians(i).mu) }
168+
val sigmas = Array.tabulate(k) { i => Matrices.fromBreeze(gaussians(i).sigma) }
226169
new GaussianMixtureModel(weights, means, sigmas)
227170
}
228171

229172
/** Average of dense breeze vectors */
230-
private def vectorMean(x: VectorArrayView): DenseDoubleVector = {
173+
private def vectorMean(x: IndexedSeq[BreezeVector[Double]]): BreezeVector[Double] = {
231174
val v = BreezeVector.zeros[Double](x(0).length)
232175
x.foreach(xi => v += xi)
233-
v / x.length.asInstanceOf[Double]
176+
v / x.length.toDouble
234177
}
235178

236179
/**
237180
* Construct matrix where diagonal entries are element-wise
238181
* variance of input vectors (computes biased variance)
239182
*/
240-
private def initCovariance(x: VectorArrayView): DenseDoubleMatrix = {
183+
private def initCovariance(x: IndexedSeq[BreezeVector[Double]]): BreezeMatrix[Double] = {
241184
val mu = vectorMean(x)
242185
val ss = BreezeVector.zeros[Double](x(0).length)
243-
val cov = BreezeMatrix.eye[Double](ss.length)
244186
x.map(xi => (xi - mu) :^ 2.0).foreach(u => ss += u)
245-
(0 until ss.length).foreach(i => cov(i,i) = ss(i) / x.length)
246-
cov
187+
diag(ss / x.length.toDouble)
247188
}
248189
}
190+
191+
// companion object providing the zero constructor and the per-point add step for ExpectationSum
192+
private object ExpectationSum {
193+
private val eps = math.pow(2.0, -52)
194+
195+
def zero(k: Int, d: Int): ExpectationSum = {
196+
new ExpectationSum(0.0, Array.fill(k)(0.0),
197+
Array.fill(k)(BreezeVector.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d)))
198+
}
199+
200+
// compute cluster contributions for each input point
201+
// (U, T) => U for aggregation
202+
def add(
203+
weights: Array[Double],
204+
dists: Array[MultivariateGaussian])
205+
(sums: ExpectationSum, x: BreezeVector[Double]): ExpectationSum = {
206+
val p = weights.zip(dists).map { case (weight, dist) => eps + weight * dist.pdf(x) }
207+
val pSum = p.sum
208+
sums.logLikelihood += math.log(pSum)
209+
val xxt = x * new Transpose(x)
210+
var i = 0
211+
while (i < sums.k) {
212+
p(i) /= pSum
213+
sums.weights(i) += p(i)
214+
sums.means(i) += x * p(i)
215+
sums.sigmas(i) += xxt * p(i) // TODO: use BLAS.dsyr
216+
i = i + 1
217+
}
218+
sums
219+
}
220+
}
221+
222+
// Aggregation class for partial expectation results
223+
private class ExpectationSum(
224+
var logLikelihood: Double,
225+
val weights: Array[Double],
226+
val means: Array[BreezeVector[Double]],
227+
val sigmas: Array[BreezeMatrix[Double]]) extends Serializable {
228+
229+
val k = weights.length
230+
231+
def +=(x: ExpectationSum): ExpectationSum = {
232+
var i = 0
233+
while (i < k) {
234+
weights(i) += x.weights(i)
235+
means(i) += x.means(i)
236+
sigmas(i) += x.sigmas(i)
237+
i = i + 1
238+
}
239+
logLikelihood += x.logLikelihood
240+
this
241+
}
242+
}

mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,21 @@
1717

1818
package org.apache.spark.mllib.stat.impl
1919

20-
import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix}
21-
import breeze.linalg.{Transpose, det, pinv}
20+
import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, Transpose, det, pinv}
2221

2322
/**
2423
* Utility class to implement the density function for multivariate Gaussian distribution.
2524
* Breeze provides this functionality, but it requires the Apache Commons Math library,
2625
* so this class is here so-as to not introduce a new dependency in Spark.
2726
*/
2827
private[mllib] class MultivariateGaussian(
29-
val mu: BreezeVector[Double],
30-
val sigma: BreezeMatrix[Double]) extends Serializable {
28+
val mu: DBV[Double],
29+
val sigma: DBM[Double]) extends Serializable {
3130
private val sigmaInv2 = pinv(sigma) * -0.5
3231
private val U = math.pow(2.0 * math.Pi, -mu.length / 2.0) * math.pow(det(sigma), -0.5)
3332

34-
def pdf(x: BreezeVector[Double]): Double = {
33+
/** Returns density of this multivariate Gaussian at given point, x */
34+
def pdf(x: DBV[Double]): Double = {
3535
val delta = x - mu
3636
val deltaTranspose = new Transpose(delta)
3737
U * math.exp(deltaTranspose * sigmaInv2 * delta)

mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
6565

6666
val gmm = new GaussianMixtureModelEM()
6767
.setK(2)
68-
.setInitialGmm(initialGmm)
68+
.setInitialModel(initialGmm)
6969
.run(data)
7070

7171
assert(gmm.weight(0) ~== Ew(0) absTol 1E-3)

0 commit comments

Comments
 (0)