@@ -18,34 +18,20 @@ package org.apache.spark.mllib.rdd
1818
1919import breeze .linalg .{axpy , Vector => BV }
2020
21- import org .apache .spark .mllib .linalg .{ Vector , Vectors }
21+ import org .apache .spark .mllib .linalg .Vector
2222import org .apache .spark .rdd .RDD
2323
2424/**
2525 * Case class of the summary statistics, including mean, variance, count, max, min, and non-zero
2626 * elements count.
2727 */
28- case class VectorRDDStatisticalSummary (
29- mean : Vector ,
30- variance : Vector ,
31- count : Long ,
32- max : Vector ,
33- min : Vector ,
34- nonZeroCnt : Vector ) extends Serializable
35-
36- /**
37- * Case class of the aggregate value for collecting summary statistics from RDD[Vector]. These
38- * values are relatively with
39- * [[org.apache.spark.mllib.rdd.VectorRDDStatisticalSummary VectorRDDStatisticalSummary ]], the
40- * latter is computed from the former.
41- */
42- private case class VectorRDDStatisticalRing (
43- fakeMean : BV [Double ],
44- fakeM2n : BV [Double ],
45- totalCnt : Double ,
46- nnz : BV [Double ],
47- fakeMax : BV [Double ],
48- fakeMin : BV [Double ])
28+ case class VectorRDDStatisticalAggregator (
29+ mean : BV [Double ],
30+ statCounter : BV [Double ],
31+ totalCount : Double ,
32+ numNonZeros : BV [Double ],
33+ max : BV [Double ],
34+ min : BV [Double ])
4935
5036/**
5137 * Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector ]] through an
@@ -58,11 +44,12 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
5844 * Aggregate function used for aggregating elements in a worker together.
5945 */
6046 private def seqOp (
61- aggregator : VectorRDDStatisticalRing ,
62- currData : BV [Double ]): VectorRDDStatisticalRing = {
47+ aggregator : VectorRDDStatisticalAggregator ,
48+ currData : BV [Double ]): VectorRDDStatisticalAggregator = {
6349 aggregator match {
64- case VectorRDDStatisticalRing (prevMean, prevM2n, cnt, nnzVec, maxVec, minVec) =>
50+ case VectorRDDStatisticalAggregator (prevMean, prevM2n, cnt, nnzVec, maxVec, minVec) =>
6551 currData.activeIterator.foreach {
52+ case (id, 0.0 ) =>
6653 case (id, value) =>
6754 if (maxVec(id) < value) maxVec(id) = value
6855 if (minVec(id) > value) minVec(id) = value
@@ -74,7 +61,7 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
7461 nnzVec(id) += 1.0
7562 }
7663
77- VectorRDDStatisticalRing (
64+ VectorRDDStatisticalAggregator (
7865 prevMean,
7966 prevM2n,
8067 cnt + 1.0 ,
@@ -88,11 +75,11 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
8875 * Combine function used for combining intermediate results together from every worker.
8976 */
9077 private def combOp (
91- statistics1 : VectorRDDStatisticalRing ,
92- statistics2 : VectorRDDStatisticalRing ): VectorRDDStatisticalRing = {
78+ statistics1 : VectorRDDStatisticalAggregator ,
79+ statistics2 : VectorRDDStatisticalAggregator ): VectorRDDStatisticalAggregator = {
9380 (statistics1, statistics2) match {
94- case (VectorRDDStatisticalRing (mean1, m2n1, cnt1, nnz1, max1, min1),
95- VectorRDDStatisticalRing (mean2, m2n2, cnt2, nnz2, max2, min2)) =>
81+ case (VectorRDDStatisticalAggregator (mean1, m2n1, cnt1, nnz1, max1, min1),
82+ VectorRDDStatisticalAggregator (mean2, m2n2, cnt2, nnz2, max2, min2)) =>
9683 val totalCnt = cnt1 + cnt2
9784 val deltaMean = mean2 - mean1
9885
@@ -120,51 +107,50 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
120107 }
121108
122109 axpy(1.0 , nnz2, nnz1)
123- VectorRDDStatisticalRing (mean1, m2n1, totalCnt, nnz1, max1, min1)
110+ VectorRDDStatisticalAggregator (mean1, m2n1, totalCnt, nnz1, max1, min1)
124111 }
125112 }
126113
127114 /**
128115 * Compute full column-wise statistics for the RDD with the size of Vector as input parameter.
129116 */
130- def summarizeStatistics (size : Int ): VectorRDDStatisticalSummary = {
131- val zeroValue = VectorRDDStatisticalRing (
117+ def summarizeStatistics (): VectorRDDStatisticalAggregator = {
118+ val size = self.take(1 ).head.size
119+ val zeroValue = VectorRDDStatisticalAggregator (
132120 BV .zeros[Double ](size),
133121 BV .zeros[Double ](size),
134122 0.0 ,
135123 BV .zeros[Double ](size),
136124 BV .fill(size)(Double .MinValue ),
137125 BV .fill(size)(Double .MaxValue ))
138126
139- val VectorRDDStatisticalRing (fakeMean, fakeM2n , totalCnt, nnz, fakeMax, fakeMin ) =
127+ val VectorRDDStatisticalAggregator (currMean, currM2n , totalCnt, nnz, currMax, currMin ) =
140128 self.map(_.toBreeze).aggregate(zeroValue)(seqOp, combOp)
141129
142130 // solve real mean
143- val realMean = fakeMean :* nnz :/ totalCnt
131+ val realMean = currMean :* nnz :/ totalCnt
144132
145133 // solve real m2n
146- val deltaMean = fakeMean
147- val realM2n = fakeM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
134+ val deltaMean = currMean
135+ val realM2n = currM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
148136
149137 // remove the initial value in max and min, i.e. the Double.MaxValue or Double.MinValue.
150- val max = Vectors .sparse(size, fakeMax.activeIterator.map { case (id, value) =>
151- if ((value == Double .MinValue ) && (realMean(id) != Double .MinValue )) (id, 0.0 )
152- else (id, value)
153- }.toSeq)
154- val min = Vectors .sparse(size, fakeMin.activeIterator.map { case (id, value) =>
155- if ((value == Double .MaxValue ) && (realMean(id) != Double .MaxValue )) (id, 0.0 )
156- else (id, value)
157- }.toSeq)
138+ nnz.activeIterator.foreach {
139+ case (id, 0.0 ) =>
140+ currMax(id) = 0.0
141+ currMin(id) = 0.0
142+ case _ =>
143+ }
158144
159145 // get variance
160146 realM2n :/= totalCnt
161147
162- VectorRDDStatisticalSummary (
163- Vectors .fromBreeze( realMean) ,
164- Vectors .fromBreeze( realM2n) ,
165- totalCnt.toLong ,
166- Vectors .fromBreeze( nnz) ,
167- max ,
168- min )
148+ VectorRDDStatisticalAggregator (
149+ realMean,
150+ realM2n,
151+ totalCnt,
152+ nnz,
153+ currMax ,
154+ currMin )
169155 }
170- }
156+ }
0 commit comments