 
 package org.apache.spark.mllib.rdd
 
-import breeze.linalg.{axpy, Vector => BV}
+import breeze.linalg.{Vector => BV, DenseVector => BDV}
 
 import org.apache.spark.mllib.linalg.{Vectors, Vector}
 import org.apache.spark.rdd.RDD
@@ -29,60 +29,59 @@ import org.apache.spark.rdd.RDD
 trait VectorRDDStatisticalSummary {
   def mean: Vector
   def variance: Vector
-  def totalCount: Long
+  def count: Long
   def numNonZeros: Vector
   def max: Vector
   def min: Vector
 }
 
 /**
  * Aggregates [[org.apache.spark.mllib.rdd.VectorRDDStatisticalSummary VectorRDDStatisticalSummary]]
- * together with add() and merge() function.
+ * together with add() and merge() functions. An online variance algorithm is used in add(), while
+ * a parallel variance algorithm is used in merge(). Reference:
+ * [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]]. Zero elements
+ * are skipped when calling add() and merge(), which reduces the O(n) algorithm to O(nnz). The real
+ * variance is computed afterwards, once the other statistics are available, by one more parallel
+ * combination step.
  */
-private class Aggregator(
-    val currMean: BV[Double],
-    val currM2n: BV[Double],
+private class VectorRDDStatisticsAggregator(
+    val currMean: BDV[Double],
+    val currM2n: BDV[Double],
     var totalCnt: Double,
-    val nnz: BV[Double],
-    val currMax: BV[Double],
-    val currMin: BV[Double]) extends VectorRDDStatisticalSummary with Serializable {
+    val nnz: BDV[Double],
+    val currMax: BDV[Double],
+    val currMin: BDV[Double]) extends VectorRDDStatisticalSummary with Serializable {
 
   // lazy val is used so that each statistic is computed only once. Same below.
   override lazy val mean = Vectors.fromBreeze(currMean :* nnz :/ totalCnt)
 
-  // Online variance solution used in add() function, while parallel variance solution used in
-  // merge() function. Reference here:
-  // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
-  // Solution here ignoring the zero elements when calling add() and merge(), for decreasing the
-  // O(n) algorithm to O(nnz). Real variance is computed here after we get other statistics, simply
-  // by another parallel combination process.
   override lazy val variance = {
     val deltaMean = currMean
     var i = 0
-    while(i < currM2n.size) {
-      currM2n(i) += deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt-nnz(i)) / totalCnt
+    while (i < currM2n.size) {
+      currM2n(i) += deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
       currM2n(i) /= totalCnt
       i += 1
     }
     Vectors.fromBreeze(currM2n)
   }
 
-  override lazy val totalCount: Long = totalCnt.toLong
+  override lazy val count: Long = totalCnt.toLong
 
   override lazy val numNonZeros: Vector = Vectors.fromBreeze(nnz)
 
   override lazy val max: Vector = {
     nnz.iterator.foreach {
       case (id, count) =>
-        if ((count == 0.0) || ((count < totalCnt) && (currMax(id) < 0.0))) currMax(id) = 0.0
+        if ((count < totalCnt) && (currMax(id) < 0.0)) currMax(id) = 0.0
     }
     Vectors.fromBreeze(currMax)
   }
 
   override lazy val min: Vector = {
     nnz.iterator.foreach {
       case (id, count) =>
-        if ((count == 0.0) || ((count < totalCnt) && (currMin(id) > 0.0))) currMin(id) = 0.0
+        if ((count < totalCnt) && (currMin(id) > 0.0)) currMin(id) = 0.0
     }
     Vectors.fromBreeze(currMin)
   }
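
Aside (illustration only, not part of the patch): the correction in the variance override above recovers the full-column population variance from statistics accumulated over the nonzero entries alone; the skipped zeros are accounted for by the deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt term before the final division by totalCnt. A minimal standalone sketch of that identity, with made-up names and data:

    object ZeroSkipVarianceCheck {
      def main(args: Array[String]): Unit = {
        val column = Array(0.0, 3.0, 0.0, 5.0, 4.0, 0.0, 0.0, 2.0)
        val n = column.length.toDouble

        // Direct population variance over all n entries, zeros included.
        val fullMean = column.sum / n
        val direct = column.map(x => (x - fullMean) * (x - fullMean)).sum / n

        // Accumulate the mean and the sum of squared deviations over the nonzero
        // entries only, mirroring what add() does while skipping zeros.
        val nonZeros = column.filter(_ != 0.0)
        val nnz = nonZeros.length.toDouble
        val nzMean = nonZeros.sum / nnz
        val m2n = nonZeros.map(x => (x - nzMean) * (x - nzMean)).sum

        // The same correction as in the patch, applied once totalCnt and nnz are known.
        val corrected = (m2n + nzMean * nzMean * nnz * (n - nnz) / n) / n

        println(s"direct = $direct, corrected = $corrected") // both print 3.6875
      }
    }
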
@@ -92,7 +91,7 @@ private class Aggregator(
    */
   def add(currData: BV[Double]): this.type = {
     currData.activeIterator.foreach {
-      // this case is used for filtering the zero elements if the vector is a dense one.
+      // this case is used for filtering out the zero elements of the vector.
       case (id, 0.0) =>
       case (id, value) =>
         if (currMax(id) < value) currMax(id) = value
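
Aside (illustration only; the rest of add()'s body is elided from this diff): per the scaladoc, add() applies an online (Welford-style) update for each nonzero component, so one new value updates the running mean and M2 in constant time. A REPL-style sketch of that update for a single column:

    var n = 0.0
    var mean = 0.0
    var m2 = 0.0

    // Online update: fold one new value into the running mean and sum of squared deviations.
    def addValue(value: Double): Unit = {
      n += 1.0
      val delta = value - mean
      mean += delta / n
      m2 += delta * (value - mean)
    }

    Seq(3.0, 5.0, 4.0, 2.0).foreach(addValue)
    // mean == 3.5, m2 == 5.0 (sum of squared deviations of the values seen so far)
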
@@ -112,7 +111,7 @@ private class Aggregator(
   /**
    * Combine function used for combining intermediate results together from every worker.
    */
-  def merge(other: Aggregator): this.type = {
+  def merge(other: VectorRDDStatisticsAggregator): this.type = {
 
     totalCnt += other.totalCnt
 
@@ -145,7 +144,7 @@ private class Aggregator(
         if (currMin(id) > value) currMin(id) = value
     }
 
-    axpy(1.0, other.nnz, nnz)
+    nnz += other.nnz
     this
   }
 }
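
Aside (illustration only; most of merge()'s body is elided from this diff): the scaladoc points at the parallel variance combination from the referenced Wikipedia article. For a single column, two partial (count, mean, M2) summaries combine without revisiting the data, which is the per-component idea behind merge(). A REPL-style sketch under those assumptions:

    case class PartialStats(count: Double, mean: Double, m2: Double)

    // Pairwise combination of two partial summaries (Chan et al. formula).
    def combine(a: PartialStats, b: PartialStats): PartialStats = {
      val totalCount = a.count + b.count
      val delta = b.mean - a.mean
      val mergedMean = a.mean + delta * b.count / totalCount
      val mergedM2 = a.m2 + b.m2 + delta * delta * a.count * b.count / totalCount
      PartialStats(totalCount, mergedMean, mergedM2)
    }

    // Partitions [1, 2, 3] and [4, 5] merge to count = 5, mean = 3.0, m2 = 10.0,
    // matching a single pass over [1, 2, 3, 4, 5].
    val merged = combine(PartialStats(3, 2.0, 2.0), PartialStats(2, 4.5, 0.5))
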
@@ -160,18 +159,18 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
   /**
    * Compute full column-wise statistics for the RDD with the size of Vector as input parameter.
    */
-  def summarizeStatistics(): VectorRDDStatisticalSummary = {
-    val size = self.take(1).head.size
+  def computeSummaryStatistics(): VectorRDDStatisticalSummary = {
+    val size = self.first().size
 
-    val zeroValue = new Aggregator(
-      BV.zeros[Double](size),
-      BV.zeros[Double](size),
+    val zeroValue = new VectorRDDStatisticsAggregator(
+      BDV.zeros[Double](size),
+      BDV.zeros[Double](size),
       0.0,
-      BV.zeros[Double](size),
-      BV.fill(size)(Double.MinValue),
-      BV.fill(size)(Double.MaxValue))
+      BDV.zeros[Double](size),
+      BDV.fill(size)(Double.MinValue),
+      BDV.fill(size)(Double.MaxValue))
 
-    self.map(_.toBreeze).aggregate[Aggregator](zeroValue)(
+    self.map(_.toBreeze).aggregate[VectorRDDStatisticsAggregator](zeroValue)(
       (aggregator, data) => aggregator.add(data),
       (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
     )
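
Aside (usage sketch, not part of the patch): how the renamed API might be called, assuming a live SparkContext named sc and that VectorRDDFunctions is reachable from the caller; any implicit conversion from RDD[Vector] is outside this diff, so the class is constructed directly here.

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.rdd.VectorRDDFunctions

    val data = sc.parallelize(Seq(
      Vectors.dense(1.0, 0.0, 3.0),
      Vectors.dense(0.0, 2.0, 0.0),
      Vectors.dense(4.0, 0.0, 5.0)))

    val summary = new VectorRDDFunctions(data).computeSummaryStatistics()
    summary.mean        // column-wise mean, zeros included
    summary.variance    // column-wise population variance
    summary.count       // 3 (number of vectors)
    summary.numNonZeros // nonzero entries per column
    summary.max         // column-wise max
    summary.min         // column-wise min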