@@ -19,16 +19,15 @@ package org.apache.spark.mllib.clustering
 
 import scala.reflect.ClassTag
 
-import breeze.linalg.{Vector => BV}
-
-import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.Logging
-import org.apache.spark.mllib.linalg.{Vectors, Vector}
-import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
-import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.streaming.StreamingContext._
+import org.apache.spark.streaming.dstream.DStream
 import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
 
 /**
  * :: DeveloperApi ::
@@ -66,55 +65,81 @@ import org.apache.spark.util.Utils
 @DeveloperApi
 class StreamingKMeansModel(
     override val clusterCenters: Array[Vector],
-    val clusterCounts: Array[Long]) extends KMeansModel(clusterCenters) with Logging {
+    val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging {
 
   /** Perform a k-means update on a batch of data. */
   def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = {
 
-    val centers = clusterCenters
-    val counts = clusterCounts
-
     // find nearest cluster to each point
-    val closest = data.map(point => (this.predict(point), (point.toBreeze, 1.toLong)))
+    val closest = data.map(point => (this.predict(point), (point, 1L)))
 
     // get sums and counts for updating each cluster
-    type WeightedPoint = (BV[Double], Long)
-    def mergeContribs(p1: WeightedPoint, p2: WeightedPoint): WeightedPoint = {
-      (p1._1 += p2._1, p1._2 + p2._2)
+    val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => {
+      BLAS.axpy(1.0, p2._1, p1._1)
+      (p1._1, p1._2 + p2._2)
     }
-    val pointStats: Array[(Int, (BV[Double], Long))] =
-      closest.reduceByKey(mergeContribs).collect()
+    val dim = clusterCenters(0).size
+    val pointStats: Array[(Int, (Vector, Long))] = closest
+      .aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs)
+      .collect()
+
+    val discount = timeUnit match {
+      case StreamingKMeans.BATCHES => decayFactor
+      case StreamingKMeans.POINTS =>
+        val numNewPoints = pointStats.view.map { case (_, (_, n)) =>
+          n
+        }.sum
+        math.pow(decayFactor, numNewPoints)
+    }
+
+    // apply discount to weights
+    BLAS.scal(discount, Vectors.dense(clusterWeights))
 
     // implement update rule
-    pointStats.foreach { case (label, (mean, count)) =>
-      // store old count and centroid
-      val oldCount = counts(label)
-      val oldCentroid = centers(label).toBreeze
-      // get new count and centroid
-      val newCount = count
-      val newCentroid = mean / newCount.toDouble
-      // compute the normalized scale factor that controls forgetting
-      val lambda = timeUnit match {
-        case "batches" => newCount / (decayFactor * oldCount + newCount)
-        case "points" => newCount / (math.pow(decayFactor, newCount) * oldCount + newCount)
-      }
-      // perform the update
-      val updatedCentroid = oldCentroid + (newCentroid - oldCentroid) * lambda
-      // store the new counts and centers
-      counts(label) = oldCount + newCount
-      centers(label) = Vectors.fromBreeze(updatedCentroid)
+    pointStats.foreach { case (label, (sum, count)) =>
+      val centroid = clusterCenters(label)
+
+      val updatedWeight = clusterWeights(label) + count
+      val lambda = count / math.max(updatedWeight, 1e-16)
+
+      clusterWeights(label) = updatedWeight
+      BLAS.scal(1.0 - lambda, centroid)
+      BLAS.axpy(lambda / count, sum, centroid)
 
       // display the updated cluster centers
-      val display = centers(label).size match {
-        case x if x > 100 => centers(label).toArray.take(100).mkString("[", ",", "...")
-        case _ => centers(label).toArray.mkString("[", ",", "]")
+      val display = clusterCenters(label).size match {
+        case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...")
+        case _ => centroid.toArray.mkString("[", ",", "]")
+      }
+
+      logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display")
+    }
+
+    // Check whether the smallest cluster is dying. If so, split the largest cluster.
+    val weightsWithIndex = clusterWeights.view.zipWithIndex
+    val (maxWeight, largest) = weightsWithIndex.maxBy(_._1)
+    val (minWeight, smallest) = weightsWithIndex.minBy(_._1)
+    if (minWeight < 1e-8 * maxWeight) {
+      logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.")
+      val weight = (maxWeight + minWeight) / 2.0
+      clusterWeights(largest) = weight
+      clusterWeights(smallest) = weight
+      val largestClusterCenter = clusterCenters(largest)
+      val smallestClusterCenter = clusterCenters(smallest)
+      var j = 0
+      while (j < dim) {
+        val x = largestClusterCenter(j)
+        val p = 1e-14 * math.max(math.abs(x), 1.0)
+        largestClusterCenter.toBreeze(j) = x + p
+        smallestClusterCenter.toBreeze(j) = x - p
+        j += 1
       }
-      logInfo("Cluster %d updated: %s".format(label, display))
     }
-    new StreamingKMeansModel(centers, counts)
-  }
 
+    this
+  }
 }
+
 /**
  * :: DeveloperApi ::
  * StreamingKMeans provides methods for configuring a
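
Note on the rewritten update rule in the hunk above: instead of computing a per-cluster lambda from the time unit inline, the patch first discounts every cluster weight by the decay factor (once per batch, or decayFactor^numNewPoints for per-point decay), then folds each cluster's batch contribution in as a convex combination with gain lambda = count / (discounted weight + count). A minimal standalone sketch of the same arithmetic for a single cluster, on plain arrays rather than MLlib vectors (all names here are illustrative, not part of the patch):

// One forgetful k-means step for a single cluster (assumes count > 0).
// State: centroid with accumulated weight; batch: `count` points whose
// coordinate-wise sum is `sum`. Returns the new weight.
def updateOne(
    centroid: Array[Double],
    weight: Double,
    sum: Array[Double],
    count: Long,
    discount: Double): Double = {
  val discounted = weight * discount                    // decay the old evidence
  val updatedWeight = discounted + count
  val lambda = count / math.max(updatedWeight, 1e-16)   // gain for the new batch
  var i = 0
  while (i < centroid.length) {
    // c <- (1 - lambda) * c + lambda * (sum / count), mirroring BLAS.scal + BLAS.axpy
    centroid(i) = (1.0 - lambda) * centroid(i) + (lambda / count) * sum(i)
    i += 1
  }
  updatedWeight
}

With discount = 1.0 this reduces to the running mean of ordinary k-means; with discount < 1.0 old batches decay geometrically, which is the forgetfulness that decayFactor controls.
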
@@ -128,7 +153,7 @@ class StreamingKMeansModel(
  * val model = new StreamingKMeans()
  *   .setDecayFactor(0.5)
  *   .setK(3)
- *   .setRandomCenters(5)
+ *   .setRandomCenters(5, 100.0)
  *   .trainOn(DStream)
  */
 @DeveloperApi
@@ -137,9 +162,9 @@ class StreamingKMeans(
     var decayFactor: Double,
     var timeUnit: String) extends Logging {
 
-  protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)
+  def this() = this(2, 1.0, StreamingKMeans.BATCHES)
 
-  def this() = this(2, 1.0, "batches")
+  protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)
 
   /** Set the number of clusters. */
   def setK(k: Int): this.type = {
@@ -155,7 +180,7 @@ class StreamingKMeans(
 
   /** Set the half life and time unit ("batches" or "points") for forgetful algorithms. */
   def setHalfLife(halfLife: Double, timeUnit: String): this.type = {
-    if (timeUnit != "batches" && timeUnit != "points") {
+    if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) {
       throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit)
     }
     this.decayFactor = math.exp(math.log(0.5) / halfLife)
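
A quick worked check of the conversion in the last context line: decayFactor = exp(log(0.5) / halfLife) is just 0.5^(1 / halfLife), so raising it to the halfLife-th power gives back 0.5, i.e. a batch's influence halves after halfLife time units. For example (illustrative, not part of the patch):

// halfLife = 10 gives decayFactor = 0.5^(1/10) ≈ 0.933,
// and 0.933^10 ≈ 0.5: a batch's weight halves after ten batches.
val halfLife = 10.0
val decayFactor = math.exp(math.log(0.5) / halfLife)
assert(math.abs(math.pow(decayFactor, halfLife) - 0.5) < 1e-9)
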
@@ -165,26 +190,23 @@ class StreamingKMeans(
   }
 
   /** Specify initial centers directly. */
-  def setInitialCenters(initialCenters: Array[Vector]): this.type = {
-    val clusterCounts = new Array[Long](this.k)
-    this.model = new StreamingKMeansModel(initialCenters, clusterCounts)
+  def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = {
+    model = new StreamingKMeansModel(centers, weights)
     this
   }
 
-  /** Initialize random centers, requiring only the number of dimensions.
-   *
-   * @param dim Number of dimensions
-   * @param seed Random seed
-   * */
-  def setRandomCenters(dim: Int, seed: Long = Utils.random.nextLong): this.type = {
-
-    val random = Utils.random
-    random.setSeed(seed)
-
-    val initialCenters = (0 until k)
-      .map(_ => Vectors.dense(Array.fill(dim)(random.nextGaussian()))).toArray
-    val clusterCounts = new Array[Long](this.k)
-    this.model = new StreamingKMeansModel(initialCenters, clusterCounts)
+  /**
+   * Initialize random centers, requiring only the number of dimensions.
+   *
+   * @param dim Number of dimensions
+   * @param weight Weight for each center
+   * @param seed Random seed
+   */
+  def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = {
+    val random = new XORShiftRandom(seed)
+    val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian())))
+    val weights = Array.fill(k)(weight)
+    model = new StreamingKMeansModel(centers, weights)
     this
   }
 
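
For reference, the two initializers above would be used roughly as follows; the values are arbitrary and `myCenters` is a placeholder, not something defined in this patch:

// Random initialization: k Gaussian centers in 5 dimensions, each with weight 100.0.
val randomInit = new StreamingKMeans()
  .setK(3)
  .setDecayFactor(0.5)
  .setRandomCenters(dim = 5, weight = 100.0, seed = 42L)

// Explicit initialization: centers and per-center weights supplied directly.
// val explicitInit = new StreamingKMeans(3, 1.0, "batches")
//   .setInitialCenters(myCenters, Array.fill(3)(100.0))

Seeding a fresh XORShiftRandom, instead of mutating the shared Utils.random as before, keeps the initialization reproducible without side effects on other callers.
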
@@ -202,9 +224,9 @@ class StreamingKMeans(
    * @param data DStream containing vector data
    */
   def trainOn(data: DStream[Vector]) {
-    this.assertInitialized()
+    assertInitialized()
     data.foreachRDD { (rdd, time) =>
-      model = model.update(rdd, this.decayFactor, this.timeUnit)
+      model = model.update(rdd, decayFactor, timeUnit)
     }
   }
 
@@ -215,7 +237,7 @@ class StreamingKMeans(
    * @return DStream containing predictions
    */
   def predictOn(data: DStream[Vector]): DStream[Int] = {
-    this.assertInitialized()
+    assertInitialized()
     data.map(model.predict)
   }
 
@@ -227,16 +249,20 @@ class StreamingKMeans(
    * @return DStream containing the input keys and the predictions as values
    */
   def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = {
-    this.assertInitialized()
+    assertInitialized()
     data.mapValues(model.predict)
   }
 
   /** Check whether cluster centers have been initialized. */
-  def assertInitialized(): Unit = {
-    if (Option(model.clusterCenters) == None) {
+  private[this] def assertInitialized(): Unit = {
+    if (model.clusterCenters == null) {
       throw new IllegalStateException(
         "Initial cluster centers must be set before starting predictions")
     }
   }
+}
 
+private[clustering] object StreamingKMeans {
+  final val BATCHES = "batches"
+  final val POINTS = "points"
 }
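
End-to-end, the reworked API is exercised roughly as below; the StreamingContext, socket sources, and all literal values are illustrative assumptions, not part of the patch:

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.streaming.{Seconds, StreamingContext}

// Assumes an existing SparkContext `sc` and two local socket text sources
// emitting whitespace-separated numbers (both are placeholders).
val ssc = new StreamingContext(sc, Seconds(1))
val parse = (line: String) => Vectors.dense(line.split(' ').map(_.toDouble))
val trainingData = ssc.socketTextStream("localhost", 9999).map(parse)
val testData = ssc.socketTextStream("localhost", 9998).map(parse)

val model = new StreamingKMeans()
  .setK(2)
  .setHalfLife(5.0, "batches")             // decayFactor derived from the half life
  .setRandomCenters(dim = 2, weight = 0.0)

model.trainOn(trainingData)                // updates the model on every batch
model.predictOn(testData).print()          // cluster index per test vector

ssc.start()
ssc.awaitTermination()

Since assertInitialized() is now private[this], callers can no longer invoke it directly; a missing initialization surfaces as the IllegalStateException thrown when trainOn or predictOn is called.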