Skip to content

Commit 64f3175

Browse files
DB Tsaimengxr
authored andcommitted
[SPARK-4611][MLlib] Implement the efficient vector norm
The vector norm in breeze is implemented by `activeIterator` which is known to be very slow. In this PR, an efficient vector norm is implemented, and with this API, `Normalizer` and `k-means` have big performance improvement. Here is the benchmark against mnist8m dataset. a) `Normalizer` Before DenseVector: 68.25secs SparseVector: 17.01secs With this PR DenseVector: 12.71secs SparseVector: 2.73secs b) `k-means` Before DenseVector: 83.46secs SparseVector: 61.60secs With this PR DenseVector: 70.04secs SparseVector: 59.05secs Author: DB Tsai <[email protected]> Closes #3462 from dbtsai/norm and squashes the following commits: 63c7165 [DB Tsai] typo 0c3637f [DB Tsai] add import org.apache.spark.SparkContext._ back 6fa616c [DB Tsai] address feedback 9b7cb56 [DB Tsai] move norm to static method 0b632e6 [DB Tsai] kmeans dbed124 [DB Tsai] style c1a877c [DB Tsai] first commit
1 parent b0a46d8 commit 64f3175

File tree

4 files changed

+79
-6
lines changed

4 files changed

+79
-6
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.clustering
1919

2020
import scala.collection.mutable.ArrayBuffer
2121

22-
import breeze.linalg.{DenseVector => BDV, Vector => BV, norm => breezeNorm}
22+
import breeze.linalg.{DenseVector => BDV, Vector => BV}
2323

2424
import org.apache.spark.annotation.Experimental
2525
import org.apache.spark.Logging
@@ -125,7 +125,7 @@ class KMeans private (
125125
}
126126

127127
// Compute squared norms and cache them.
128-
val norms = data.map(v => breezeNorm(v.toBreeze, 2.0))
128+
val norms = data.map(Vectors.norm(_, 2.0))
129129
norms.persist()
130130
val breezeData = data.map(_.toBreeze).zip(norms).map { case (v, norm) =>
131131
new BreezeVectorWithNorm(v, norm)
@@ -425,7 +425,7 @@ object KMeans {
425425
private[clustering]
426426
class BreezeVectorWithNorm(val vector: BV[Double], val norm: Double) extends Serializable {
427427

428-
def this(vector: BV[Double]) = this(vector, breezeNorm(vector, 2.0))
428+
def this(vector: BV[Double]) = this(vector, Vectors.norm(Vectors.fromBreeze(vector), 2.0))
429429

430430
def this(array: Array[Double]) = this(new BDV[Double](array))
431431

mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717

1818
package org.apache.spark.mllib.feature
1919

20-
import breeze.linalg.{norm => brzNorm}
21-
2220
import org.apache.spark.annotation.Experimental
2321
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
2422

@@ -47,7 +45,7 @@ class Normalizer(p: Double) extends VectorTransformer {
4745
* @return normalized vector. If the norm of the input is zero, it will return the input vector.
4846
*/
4947
override def transform(vector: Vector): Vector = {
50-
val norm = brzNorm(vector.toBreeze, p)
48+
val norm = Vectors.norm(vector, p)
5149

5250
if (norm != 0.0) {
5351
// For dense vector, we've to allocate new memory for new output vector.

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,57 @@ object Vectors {
261261
sys.error("Unsupported Breeze vector type: " + v.getClass.getName)
262262
}
263263
}
264+
265+
/**
266+
* Returns the p-norm of this vector.
267+
* @param vector input vector.
268+
* @param p norm.
269+
* @return norm in L^p^ space.
270+
*/
271+
private[spark] def norm(vector: Vector, p: Double): Double = {
272+
require(p >= 1.0)
273+
val values = vector match {
274+
case dv: DenseVector => dv.values
275+
case sv: SparseVector => sv.values
276+
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
277+
}
278+
val size = values.size
279+
280+
if (p == 1) {
281+
var sum = 0.0
282+
var i = 0
283+
while (i < size) {
284+
sum += math.abs(values(i))
285+
i += 1
286+
}
287+
sum
288+
} else if (p == 2) {
289+
var sum = 0.0
290+
var i = 0
291+
while (i < size) {
292+
sum += values(i) * values(i)
293+
i += 1
294+
}
295+
math.sqrt(sum)
296+
} else if (p == Double.PositiveInfinity) {
297+
var max = 0.0
298+
var i = 0
299+
while (i < size) {
300+
val value = math.abs(values(i))
301+
if (value > max) max = value
302+
i += 1
303+
}
304+
max
305+
} else {
306+
var sum = 0.0
307+
var i = 0
308+
while (i < size) {
309+
sum += math.pow(math.abs(values(i)), p)
310+
i += 1
311+
}
312+
math.pow(sum, 1.0 / p)
313+
}
314+
}
264315
}
265316

266317
/**

mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import breeze.linalg.{DenseMatrix => BDM}
2121
import org.scalatest.FunSuite
2222

2323
import org.apache.spark.SparkException
24+
import org.apache.spark.mllib.util.TestingUtils._
2425

2526
class VectorsSuite extends FunSuite {
2627

@@ -197,4 +198,27 @@ class VectorsSuite extends FunSuite {
197198
assert(svMap.get(2) === Some(3.1))
198199
assert(svMap.get(3) === Some(0.0))
199200
}
201+
202+
test("vector p-norm") {
203+
val dv = Vectors.dense(0.0, -1.2, 3.1, 0.0, -4.5, 1.9)
204+
val sv = Vectors.sparse(6, Seq((1, -1.2), (2, 3.1), (3, 0.0), (4, -4.5), (5, 1.9)))
205+
206+
assert(Vectors.norm(dv, 1.0) ~== dv.toArray.foldLeft(0.0)((a, v) =>
207+
a + math.abs(v)) relTol 1E-8)
208+
assert(Vectors.norm(sv, 1.0) ~== sv.toArray.foldLeft(0.0)((a, v) =>
209+
a + math.abs(v)) relTol 1E-8)
210+
211+
assert(Vectors.norm(dv, 2.0) ~== math.sqrt(dv.toArray.foldLeft(0.0)((a, v) =>
212+
a + v * v)) relTol 1E-8)
213+
assert(Vectors.norm(sv, 2.0) ~== math.sqrt(sv.toArray.foldLeft(0.0)((a, v) =>
214+
a + v * v)) relTol 1E-8)
215+
216+
assert(Vectors.norm(dv, Double.PositiveInfinity) ~== dv.toArray.map(math.abs).max relTol 1E-8)
217+
assert(Vectors.norm(sv, Double.PositiveInfinity) ~== sv.toArray.map(math.abs).max relTol 1E-8)
218+
219+
assert(Vectors.norm(dv, 3.7) ~== math.pow(dv.toArray.foldLeft(0.0)((a, v) =>
220+
a + math.pow(math.abs(v), 3.7)), 1.0 / 3.7) relTol 1E-8)
221+
assert(Vectors.norm(sv, 3.7) ~== math.pow(sv.toArray.foldLeft(0.0)((a, v) =>
222+
a + math.pow(math.abs(v), 3.7)), 1.0 / 3.7) relTol 1E-8)
223+
}
200224
}

0 commit comments

Comments
 (0)