
Commit de24662
Author: DB Tsai
Message: address feedback
Parent: b185a77

File tree: 3 files changed, +16 / -30 lines

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala

Lines changed: 8 additions & 17 deletions

@@ -23,6 +23,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
 import org.apache.spark.SparkContext._
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.BLAS.{axpy, scal}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
@@ -168,15 +169,10 @@ class KMeans private (
 
     // Execute iterations of Lloyd's algorithm until all runs have converged
     while (iteration < maxIterations && !activeRuns.isEmpty) {
-      type WeightedPoint = (Array[Double], Long)
-      def mergeContribs(p1: WeightedPoint, p2: WeightedPoint): WeightedPoint = {
-        require(p1._1.size == p2._1.size)
-        var i = 0
-        while(i < p1._1.size) {
-          p1._1(i) += p2._1(i)
-          i += 1
-        }
-        (p1._1, p1._2 + p2._2)
+      type WeightedPoint = (Vector, Long)
+      def mergeContribs(x: WeightedPoint, y: WeightedPoint): WeightedPoint = {
+        axpy(1.0, x._1, y._1)
+        (y._1, x._2 + y._2)
       }
 
       val activeCenters = activeRuns.map(r => centers(r)).toArray
@@ -191,15 +187,15 @@
         val k = thisActiveCenters(0).length
         val dims = thisActiveCenters(0)(0).vector.size
 
-        val sums = Array.fill(runs, k)(Array.ofDim[Double](dims))
+        val sums = Array.fill(runs, k)(Vectors.zeros(dims))
         val counts = Array.fill(runs, k)(0L)
 
         points.foreach { point =>
           (0 until runs).foreach { i =>
             val (bestCenter, cost) = KMeans.findClosest(thisActiveCenters(i), point)
             costAccums(i) += cost
             val sum = sums(i)(bestCenter)
-            point.vector.foreachActive((index, value) => sum(index) += value)
+            axpy(1.0, point.vector, sum)
             counts(i)(bestCenter) += 1
           }
         }
@@ -217,12 +213,7 @@
           while (j < k) {
             val (sum, count) = totalContribs((i, j))
             if (count != 0) {
-              val size = sum.size
-              var i = 0
-              while(i < sum.size) {
-                sum(i) /= count
-                i += 1
-              }
+              scal(1.0 / count, sum)
               val newCenter = new VectorWithNorm(sum)
               if (KMeans.fastSquaredDistance(newCenter, centers(run)(j)) > epsilon * epsilon) {
                 changed = true
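
The net effect in KMeans.scala is that two hand-rolled index loops become BLAS level-1 calls: axpy (y += a * x) folds one weighted contribution into another, and scal (x *= a) later turns a running sum into a mean. Below is a minimal standalone sketch of the new mergeContribs semantics, using a local axpy stand-in since Spark's mllib BLAS object is package-private; the names here are illustrative, not Spark's API.

object WeightedMergeSketch {
  // y += a * x, the axpy contract assumed above (local stand-in, not Spark's BLAS)
  def axpy(a: Double, x: Array[Double], y: Array[Double]): Unit = {
    require(x.length == y.length)
    var i = 0
    while (i < x.length) { y(i) += a * x(i); i += 1 }
  }

  type WeightedPoint = (Array[Double], Long)

  // Same shape as the new mergeContribs: fold x into y, sum the counts.
  def mergeContribs(x: WeightedPoint, y: WeightedPoint): WeightedPoint = {
    axpy(1.0, x._1, y._1)
    (y._1, x._2 + y._2)
  }

  def main(args: Array[String]): Unit = {
    val a: WeightedPoint = (Array(1.0, 2.0), 3L)
    val b: WeightedPoint = (Array(4.0, 6.0), 5L)
    val (sum, count) = mergeContribs(a, b)
    println(sum.mkString(", ") + " / " + count)  // 5.0, 8.0 / 8
  }
}

Because axpy mutates its third argument in place, mergeContribs can reuse one side's buffer instead of allocating a fresh array per merge.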

mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala

Lines changed: 4 additions & 11 deletions

@@ -21,7 +21,7 @@ import scala.util.Random
 
 import org.apache.spark.Logging
 import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.linalg.BLAS.axpy
+import org.apache.spark.mllib.linalg.BLAS.{axpy, scal}
 
 /**
  * An utility object to run K-means locally. This is private to the ML package because it's used
@@ -75,12 +75,12 @@ private[mllib] object LocalKMeans extends Logging {
     while (moved && iteration < maxIterations) {
       moved = false
       val counts = Array.fill(k)(0.0)
-      val sums = Array.fill(k)(Array.ofDim[Double](dimensions))
+      val sums = Array.fill(k)(Vectors.zeros(dimensions))
       var i = 0
       while (i < points.length) {
         val p = points(i)
         val index = KMeans.findClosest(centers, p)._1
-        axpy(weights(i), p.vector, Vectors.dense(sums(index)))
+        axpy(weights(i), p.vector, sums(index))
         counts(index) += weights(i)
         if (index != oldClosest(i)) {
           moved = true
@@ -95,14 +95,7 @@
         // Assign center to a random point
         centers(j) = points(rand.nextInt(points.length)).toDense
       } else {
-        val sum = sums(j)
-        val count = counts(j)
-        val size = sum.size
-        var i = 0
-        while(i < size) {
-          sum(i) /= count
-          i += 1
-        }
+        scal(1.0 / counts(j), sums(j))
        centers(j) = new VectorWithNorm(sums(j))
       }
       j += 1
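
LocalKMeans gets the same treatment: scal(1.0 / counts(j), sums(j)) replaces the per-element divide loop, scaling the weighted sum by the reciprocal of the total weight to produce the centroid in place. A small illustration with a local scal stand-in (again a hypothetical helper, mirroring only the x *= a contract, not Spark's package-private BLAS):

object CentroidSketch {
  // x *= a, the scal contract assumed above (local stand-in)
  def scal(a: Double, x: Array[Double]): Unit = {
    var i = 0
    while (i < x.length) { x(i) *= a; i += 1 }
  }

  def main(args: Array[String]): Unit = {
    val sum = Array(9.0, 3.0, 6.0)  // weighted sum of points in one cluster
    val count = 3.0                 // total weight of that cluster
    scal(1.0 / count, sum)          // sum now holds the centroid
    println(sum.mkString(", "))     // 3.0, 1.0, 2.0
  }
}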

mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala

Lines changed: 4 additions & 2 deletions

@@ -311,9 +311,11 @@ object MLUtils {
     } else if (v1.isInstanceOf[SparseVector] || v2.isInstanceOf[SparseVector]) {
       val dotValue = dot(v1, v2)
       sqDist = math.max(sumSquaredNorm - 2.0 * dotValue, 0.0)
-      val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dotValue)) / (sqDist + EPSILON)
+      val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dotValue)) /
+        (sqDist + EPSILON)
       if (precisionBound2 > precision) {
-        // TODO: breezeSquaredDistance is slow, so we should replace it with our own implementation.
+        // TODO: breezeSquaredDistance is slow,
+        // so we should replace it with our own implementation.
         sqDist = breezeSquaredDistance(v1.toBreeze, v2.toBreeze)
       }
     } else {
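
For context, this branch of MLUtils computes the squared distance through the norm identity sqDist = ||v1||^2 + ||v2||^2 - 2 * (v1 . v2), which can lose precision when the two terms nearly cancel; the reformatted bound EPSILON * (sumSquaredNorm + 2.0 * math.abs(dotValue)) / (sqDist + EPSILON) estimates that relative error and triggers an exact recomputation when it exceeds the requested precision. A self-contained sketch of the same guard on dense arrays (EPSILON, dot, exactSqDist, and fastSqDist are local stand-ins, not Spark's API):

object SquaredDistanceSketch {
  // Machine epsilon: smallest eps with 1.0 + eps/2 == 1.0 in Double arithmetic.
  val EPSILON: Double = {
    var eps = 1.0
    while ((1.0 + eps / 2.0) > 1.0) eps /= 2.0
    eps
  }

  def dot(x: Array[Double], y: Array[Double]): Double =
    x.zip(y).map { case (a, b) => a * b }.sum

  def exactSqDist(x: Array[Double], y: Array[Double]): Double =
    x.zip(y).map { case (a, b) => (a - b) * (a - b) }.sum

  // Fast path via the norm identity, with an exact fallback when
  // cancellation error may dominate the result.
  def fastSqDist(x: Array[Double], y: Array[Double], precision: Double): Double = {
    val sumSquaredNorm = dot(x, x) + dot(y, y)
    val dotValue = dot(x, y)
    var sqDist = math.max(sumSquaredNorm - 2.0 * dotValue, 0.0)
    val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dotValue)) /
      (sqDist + EPSILON)
    if (precisionBound2 > precision) sqDist = exactSqDist(x, y)
    sqDist
  }

  def main(args: Array[String]): Unit = {
    val x = Array(1.0, 2.0, 3.0)
    val y = Array(1.0, 2.0, 3.000001)  // nearly identical vectors force the fallback
    println(fastSqDist(x, y, precision = 1e-6))  // ~1e-12
  }
}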
