
Commit 18cf072

change def to lazy val to make sure that the computations in each function are evaluated only once
1 parent f7a3ca2 commit 18cf072
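
For context on the message above: a parameterless def re-runs its body on every access, while a lazy val evaluates the body once, on first access, and caches the result. A minimal standalone sketch (not part of this patch):

    object LazyValDemo extends App {
      def viaDef: Double = { println("computing viaDef"); math.Pi * 2 }
      lazy val viaLazyVal: Double = { println("computing viaLazyVal"); math.Pi * 2 }

      viaDef; viaDef              // "computing viaDef" prints twice
      viaLazyVal; viaLazyVal      // "computing viaLazyVal" prints once; the second read is cached
    }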

File tree

2 files changed: 38 additions, 27 deletions

mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala

Lines changed: 24 additions & 19 deletions
@@ -14,6 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+
 package org.apache.spark.mllib.rdd
 
 import breeze.linalg.{axpy, Vector => BV}
@@ -26,12 +27,12 @@ import org.apache.spark.rdd.RDD
  * elements count.
  */
 trait VectorRDDStatisticalSummary {
-  def mean(): Vector
-  def variance(): Vector
-  def totalCount(): Long
-  def numNonZeros(): Vector
-  def max(): Vector
-  def min(): Vector
+  def mean: Vector
+  def variance: Vector
+  def totalCount: Long
+  def numNonZeros: Vector
+  def max: Vector
+  def min: Vector
 }
 
 private class Aggregator(
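
Dropping the empty parameter lists is what enables the change below: under Scala's uniform access principle a parameterless def can be implemented by a val or lazy val, so the cached values in Aggregator can still satisfy the trait. A minimal sketch of the pattern (names hypothetical):

    trait Summary {
      def mean: Double    // accessor-style abstract member, no parens
    }

    class CachedSummary(data: Array[Double]) extends Summary {
      // the body runs once, on first access, and the result is cached
      override lazy val mean: Double = data.sum / data.length
    }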
@@ -42,30 +43,32 @@ private class Aggregator(
     val currMax: BV[Double],
     val currMin: BV[Double]) extends VectorRDDStatisticalSummary with Serializable {
 
-  override def mean(): Vector = {
-    Vectors.fromBreeze(currMean :* nnz :/ totalCnt)
-  }
+  override lazy val mean = Vectors.fromBreeze(currMean :* nnz :/ totalCnt)
 
-  override def variance(): Vector = {
+  override lazy val variance = {
     val deltaMean = currMean
-    val realM2n = currM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
-    realM2n :/= totalCnt
-    Vectors.fromBreeze(realM2n)
+    var i = 0
+    while(i < currM2n.size) {
+      currM2n(i) -= deltaMean(i) * deltaMean(i) * nnz(i) * (nnz(i)-totalCnt) / totalCnt
+      currM2n(i) /= totalCnt
+      i += 1
+    }
+    Vectors.fromBreeze(currM2n)
   }
 
-  override def totalCount(): Long = totalCnt.toLong
+  override lazy val totalCount: Long = totalCnt.toLong
 
-  override def numNonZeros(): Vector = Vectors.fromBreeze(nnz)
+  override lazy val numNonZeros: Vector = Vectors.fromBreeze(nnz)
 
-  override def max(): Vector = {
+  override lazy val max: Vector = {
     nnz.activeIterator.foreach {
       case (id, count) =>
         if ((count == 0.0) || ((count < totalCnt) && (currMax(id) < 0.0))) currMax(id) = 0.0
     }
     Vectors.fromBreeze(currMax)
   }
 
-  override def min(): Vector = {
+  override lazy val min: Vector = {
     nnz.activeIterator.foreach {
       case (id, count) =>
         if ((count == 0.0) || ((count < totalCnt) && (currMin(id) > 0.0))) currMin(id) = 0.0
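
Two notes on this hunk. First, the variance computation corrects for the zeros that add skips: currMean and currM2n are accumulated over nonzero entries only, and the subtracted term deltaMean(i) * deltaMean(i) * nnz(i) * (nnz(i) - totalCnt) / totalCnt folds the implicit zeros back into each column's sum of squared deviations before dividing by totalCnt. Second, the lazy val here is arguably a correctness fix as much as an optimization: variance, max, and min mutate their backing buffers in place, so a plain def would return wrong results if evaluated a second time. A standalone single-column check of the correction (all names hypothetical, not from the patch):

    // one column with an implicit zero; add would only see the 1.0 and 3.0
    val xs = Array(1.0, 0.0, 3.0)
    val n = xs.length.toDouble                  // plays the role of totalCnt
    val nnz = xs.count(_ != 0.0).toDouble       // 2 nonzero entries
    val meanNZ = xs.filter(_ != 0.0).sum / nnz  // mean over nonzeros = 2.0
    val m2nNZ = xs.filter(_ != 0.0).map(v => (v - meanNZ) * (v - meanNZ)).sum

    // the patch's correction: fold the implicit zeros back in, then divide by n
    val corrected = (m2nNZ - meanNZ * meanNZ * nnz * (nnz - n) / n) / n

    // population variance computed directly over all entries, for comparison
    val mean = xs.sum / n
    val direct = xs.map(v => (v - mean) * (v - mean)).sum / n
    assert(math.abs(corrected - direct) < 1e-12)  // both equal 14/9

The max and min bodies handle zeros similarly: a column whose nonzero count is below totalCnt contains at least one implicit zero, so a negative running max (or a positive running min) is clamped to 0.0.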
@@ -78,6 +81,7 @@ private class Aggregator(
    */
   def add(currData: BV[Double]): this.type = {
     currData.activeIterator.foreach {
+      // this case is used for filtering the zero elements if the vector is a dense one.
       case (id, 0.0) =>
       case (id, value) =>
         if (currMax(id) < value) currMax(id) = value
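
The new comment is worth a gloss: for a Breeze dense vector, activeIterator visits every entry, explicit zeros included, so the empty-bodied case (id, 0.0) => is what actually filters them out; for a sparse vector only the stored entries are visited, so that case rarely fires. A tiny standalone illustration:

    import breeze.linalg.DenseVector

    DenseVector(1.0, 0.0, 2.0).activeIterator.foreach {
      case (_, 0.0) =>                        // skip explicit zeros, as add() does
      case (i, v)   => println(s"$i -> $v")   // prints 0 -> 1.0 and 2 -> 2.0
    }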
@@ -106,7 +110,8 @@ private class Aggregator(
     other.currMean.activeIterator.foreach {
       case (id, 0.0) =>
       case (id, value) =>
-        currMean(id) = (currMean(id) * nnz(id) + other.currMean(id) * other.nnz(id)) / (nnz(id) + other.nnz(id))
+        currMean(id) =
+          (currMean(id) * nnz(id) + other.currMean(id) * other.nnz(id)) / (nnz(id) + other.nnz(id))
     }
 
     other.currM2n.activeIterator.foreach {
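
The rewrapped line is the per-column merge of two partial means as a count-weighted average: merged = (mean1 * n1 + mean2 * n2) / (n1 + n2). A quick standalone check (values hypothetical):

    val (m1, n1) = (2.0, 2.0)   // mean and count of Seq(1.0, 3.0)
    val (m2, n2) = (5.0, 1.0)   // mean and count of Seq(5.0)
    val merged = (m1 * n1 + m2 * n2) / (n1 + n2)
    assert(merged == Seq(1.0, 3.0, 5.0).sum / 3)  // both are 3.0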
@@ -157,4 +162,4 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
       (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
     )
   }
-}
+}

mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala

Lines changed: 14 additions & 8 deletions
@@ -14,13 +14,15 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+
 package org.apache.spark.mllib.rdd
 
 import scala.collection.mutable.ArrayBuffer
 
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.rdd.VectorRDDFunctionsSuite._
 import org.apache.spark.mllib.util.LocalSparkContext
 import org.apache.spark.mllib.util.MLUtils._
 
@@ -29,7 +31,6 @@ import org.apache.spark.mllib.util.MLUtils._
  * between dense and sparse vector are tested.
  */
 class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
-  import VectorRDDFunctionsSuite._
 
   val localData = Array(
     Vectors.dense(1.0, 2.0, 3.0),
@@ -47,16 +48,21 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
     val (summary, denseTime) =
       time(data.summarizeStatistics())
 
-    assert(equivVector(summary.mean(), Vectors.dense(4.0, 5.0, 6.0)),
+    assert(equivVector(summary.mean, Vectors.dense(4.0, 5.0, 6.0)),
       "Column mean do not match.")
-    assert(equivVector(summary.variance(), Vectors.dense(6.0, 6.0, 6.0)),
+
+    assert(equivVector(summary.variance, Vectors.dense(6.0, 6.0, 6.0)),
       "Column variance do not match.")
-    assert(summary.totalCount() === 3, "Column cnt do not match.")
-    assert(equivVector(summary.numNonZeros(), Vectors.dense(3.0, 3.0, 3.0)),
+
+    assert(summary.totalCount === 3, "Column cnt do not match.")
+
+    assert(equivVector(summary.numNonZeros, Vectors.dense(3.0, 3.0, 3.0)),
       "Column nnz do not match.")
-    assert(equivVector(summary.max(), Vectors.dense(7.0, 8.0, 9.0)),
+
+    assert(equivVector(summary.max, Vectors.dense(7.0, 8.0, 9.0)),
       "Column max do not match.")
-    assert(equivVector(summary.min(), Vectors.dense(1.0, 2.0, 3.0)),
+
+    assert(equivVector(summary.min, Vectors.dense(1.0, 2.0, 3.0)),
       "Column min do not match.")
 
     val dataForSparse = sc.parallelize(sparseData.toSeq, 2)
@@ -82,4 +88,4 @@ object VectorRDDFunctionsSuite {
     val denominator = math.max(lhs, rhs)
     math.abs(lhs - rhs) / denominator < 0.3
   }
-}
+}
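
After this change the summary members read as fields rather than method calls. A sketch of the intended call site, assuming a live SparkContext sc and that the implicit conversion to VectorRDDFunctions is in scope (the test above appears to get it from MLUtils._):

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.util.MLUtils._

    val data = sc.parallelize(Seq(
      Vectors.dense(1.0, 2.0, 3.0),
      Vectors.dense(4.0, 5.0, 6.0),
      Vectors.dense(7.0, 8.0, 9.0)), 2)

    val summary = data.summarizeStatistics()
    println(summary.mean)        // Vectors.dense(4.0, 5.0, 6.0), per the test expectations
    println(summary.totalCount)  // 3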
