Skip to content

Commit aaa8f25

Browse files
committed
MLUtils: changed privacy of EPSILON from [util] to [mllib]
GaussianMixtureEM: Renamed from GaussianMixtureModelEM; corrected formatting issues
GaussianMixtureModel: Renamed predictLabels() to predict()
Others: Modifications based on rename of GaussianMixtureEM
1 parent 709e4bf commit aaa8f25

File tree

5 files changed

+16
-19
lines changed

5 files changed

+16
-19
lines changed

examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.examples.mllib
1919

2020
import org.apache.spark.{SparkConf, SparkContext}
21-
import org.apache.spark.mllib.clustering.GaussianMixtureModelEM
21+
import org.apache.spark.mllib.clustering.GaussianMixtureEM
2222
import org.apache.spark.mllib.linalg.Vectors
2323

2424
/**
@@ -46,7 +46,7 @@ object DenseGmmEM {
4646
Vectors.dense(line.trim.split(' ').map(_.toDouble))
4747
}.cache()
4848

49-
val clusters = new GaussianMixtureModelEM()
49+
val clusters = new GaussianMixtureEM()
5050
.setK(k)
5151
.setConvergenceTol(convergenceTol)
5252
.setMaxIterations(maxIterations)
@@ -58,7 +58,7 @@ object DenseGmmEM {
5858
}
5959

6060
println("Cluster labels (first <= 100):")
61-
val clusterLabels = clusters.predictLabels(data)
61+
val clusterLabels = clusters.predict(data)
6262
clusterLabels.take(100).foreach { x =>
6363
print(" " + x)
6464
}
Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix,
2323
import org.apache.spark.rdd.RDD
2424
import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors}
2525
import org.apache.spark.mllib.stat.impl.MultivariateGaussian
26+
import org.apache.spark.mllib.util.MLUtils
2627

2728
/**
2829
* This class performs expectation maximization for multivariate Gaussian
@@ -41,16 +42,14 @@ import org.apache.spark.mllib.stat.impl.MultivariateGaussian
4142
* is considered to have occurred.
4243
* @param maxIterations The maximum number of iterations to perform
4344
*/
44-
class GaussianMixtureModelEM private (
45+
class GaussianMixtureEM private (
4546
private var k: Int,
4647
private var convergenceTol: Double,
4748
private var maxIterations: Int) extends Serializable {
4849

4950
/** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */
5051
def this() = this(2, 0.01, 100)
5152

52-
53-
5453
// number of samples per cluster to use when initializing Gaussians
5554
private val nSamples = 5
5655

@@ -190,8 +189,6 @@ class GaussianMixtureModelEM private (
190189

191190
// companion object to provide zero constructor for ExpectationSum
192191
private object ExpectationSum {
193-
private val eps = math.pow(2.0, -52)
194-
195192
def zero(k: Int, d: Int): ExpectationSum = {
196193
new ExpectationSum(0.0, Array.fill(k)(0.0),
197194
Array.fill(k)(BreezeVector.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d)))
@@ -203,7 +200,9 @@ private object ExpectationSum {
203200
weights: Array[Double],
204201
dists: Array[MultivariateGaussian])
205202
(sums: ExpectationSum, x: BreezeVector[Double]): ExpectationSum = {
206-
val p = weights.zip(dists).map { case (weight, dist) => eps + weight * dist.pdf(x) }
203+
val p = weights.zip(dists).map {
204+
case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(x)
205+
}
207206
val pSum = p.sum
208207
sums.logLikelihood += math.log(pSum)
209208
val xxt = x * new Transpose(x)

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import breeze.linalg.{DenseVector => BreezeVector}
2222
import org.apache.spark.rdd.RDD
2323
import org.apache.spark.mllib.linalg.{Matrix, Vector}
2424
import org.apache.spark.mllib.stat.impl.MultivariateGaussian
25+
import org.apache.spark.mllib.util.MLUtils
2526

2627
/**
2728
* Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points
@@ -43,7 +44,7 @@ class GaussianMixtureModel(
4344
def k: Int = weight.length
4445

4546
/** Maps given points to their cluster indices. */
46-
def predictLabels(points: RDD[Vector]): RDD[Int] = {
47+
def predict(points: RDD[Vector]): RDD[Int] = {
4748
val responsibilityMatrix = predictMembership(points, mu, sigma, weight, k)
4849
responsibilityMatrix.map(r => r.indexOf(r.max))
4950
}
@@ -70,11 +71,6 @@ class GaussianMixtureModel(
7071
}
7172
}
7273

73-
// We use "eps" as the minimum likelihood density for any given point
74-
// in every cluster; this prevents any divide by zero conditions for
75-
// outlier points.
76-
private val eps = math.pow(2.0, -52)
77-
7874
/**
7975
* Compute the partial assignments for each vector
8076
*/
@@ -83,7 +79,9 @@ class GaussianMixtureModel(
8379
dists: Array[MultivariateGaussian],
8480
weights: Array[Double],
8581
k: Int): Array[Double] = {
86-
val p = weights.zip(dists).map { case (weight, dist) => eps + weight * dist.pdf(pt) }
82+
val p = weights.zip(dists).map {
83+
case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(pt)
84+
}
8785
val pSum = p.sum
8886
for (i <- 0 until k) {
8987
p(i) /= pSum

mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ import org.apache.spark.streaming.dstream.DStream
3838
*/
3939
object MLUtils {
4040

41-
private[util] lazy val EPSILON = {
41+
private[mllib] lazy val EPSILON = {
4242
var eps = 1.0
4343
while ((1.0 + (eps / 2.0)) != 1.0) {
4444
eps /= 2.0

mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
3636
val Emu = Vectors.dense(5.0, 10.0)
3737
val Esigma = Matrices.dense(2, 2, Array(2.0 / 3.0, -2.0 / 3.0, -2.0 / 3.0, 2.0 / 3.0))
3838

39-
val gmm = new GaussianMixtureModelEM().setK(1).run(data)
39+
val gmm = new GaussianMixtureEM().setK(1).run(data)
4040

4141
assert(gmm.weight(0) ~== Ew absTol 1E-5)
4242
assert(gmm.mu(0) ~== Emu absTol 1E-5)
@@ -63,7 +63,7 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
6363
val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604))
6464
val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644)))
6565

66-
val gmm = new GaussianMixtureModelEM()
66+
val gmm = new GaussianMixtureEM()
6767
.setK(2)
6868
.setInitialModel(initialGmm)
6969
.run(data)

0 commit comments

Comments
 (0)