  */
 package org.apache.spark.mllib.rdd
 
-import breeze.linalg.{Vector => BV, axpy}
+import breeze.linalg.{axpy, Vector => BV}
 
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.rdd.RDD
 
+/**
+ * Case class of the summary statistics, including mean, variance, count, max, min, and
+ * non-zero element counts.
+ */
 case class VectorRDDStatisticalSummary(
     mean: Vector,
     variance: Vector,
@@ -29,6 +33,12 @@ case class VectorRDDStatisticalSummary(
     min: Vector,
     nonZeroCnt: Vector) extends Serializable
 
+/**
+ * Case class of the aggregate values used while collecting summary statistics from an
+ * RDD[Vector]. These values correspond to
+ * [[org.apache.spark.mllib.rdd.VectorRDDStatisticalSummary VectorRDDStatisticalSummary]];
+ * the latter is computed from the former.
+ */
 private case class VectorRDDStatisticalRing(
     fakeMean: BV[Double],
     fakeM2n: BV[Double],
@@ -45,18 +55,8 @@ private case class VectorRDDStatisticalRing(
 class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
 
   /**
-   * Compute full column-wise statistics for the RDD, including
-   * {{{
-   *   Mean: Vector,
-   *   Variance: Vector,
-   *   Count: Double,
-   *   Non-zero count: Vector,
-   *   Maximum elements: Vector,
-   *   Minimum elements: Vector.
-   * }}},
-   * with the size of Vector as input parameter.
+   * Aggregate function used to merge elements together on a single worker.
    */
-
   private def seqOp(
       aggregator: VectorRDDStatisticalRing,
       currData: BV[Double]): VectorRDDStatisticalRing = {
@@ -84,6 +84,9 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
     }
   }
 
+  /**
+   * Combine function used to merge the intermediate results from different workers.
+   */
   private def combOp(
       statistics1: VectorRDDStatisticalRing,
       statistics2: VectorRDDStatisticalRing): VectorRDDStatisticalRing = {
@@ -92,27 +95,38 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
         VectorRDDStatisticalRing(mean2, m2n2, cnt2, nnz2, max2, min2)) =>
         val totalCnt = cnt1 + cnt2
         val deltaMean = mean2 - mean1
+
         mean2.activeIterator.foreach {
           case (id, 0.0) =>
-          case (id, value) => mean1(id) = (mean1(id) * nnz1(id) + mean2(id) * nnz2(id)) / (nnz1(id) + nnz2(id))
+          case (id, value) =>
+            mean1(id) = (mean1(id) * nnz1(id) + mean2(id) * nnz2(id)) / (nnz1(id) + nnz2(id))
         }
+
         m2n2.activeIterator.foreach {
           case (id, 0.0) =>
-          case (id, value) => m2n1(id) += value + deltaMean(id) * deltaMean(id) * nnz1(id) * nnz2(id) / (nnz1(id) + nnz2(id))
+          case (id, value) =>
+            m2n1(id) +=
+              value + deltaMean(id) * deltaMean(id) * nnz1(id) * nnz2(id) / (nnz1(id) + nnz2(id))
         }
+
         max2.activeIterator.foreach {
           case (id, value) =>
             if (max1(id) < value) max1(id) = value
         }
+
         min2.activeIterator.foreach {
           case (id, value) =>
             if (min1(id) > value) min1(id) = value
         }
+
         axpy(1.0, nnz2, nnz1)
         VectorRDDStatisticalRing(mean1, m2n1, totalCnt, nnz1, max1, min1)
     }
   }
 
+  /**
+   * Compute full column-wise statistics for the RDD, taking the size of the vectors as an input
+   * parameter.
+   */
   def summarizeStatistics(size: Int): VectorRDDStatisticalSummary = {
     val zeroValue = VectorRDDStatisticalRing(
       BV.zeros[Double](size),
@@ -122,16 +136,17 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
       BV.fill(size)(Double.MinValue),
       BV.fill(size)(Double.MaxValue))
 
-    val breezeVectors = self.map(_.toBreeze)
     val VectorRDDStatisticalRing(fakeMean, fakeM2n, totalCnt, nnz, fakeMax, fakeMin) =
-      breezeVectors.aggregate(zeroValue)(seqOp, combOp)
+      self.map(_.toBreeze).aggregate(zeroValue)(seqOp, combOp)
 
     // solve real mean
     val realMean = fakeMean :* nnz :/ totalCnt
-    // solve real variance
-    val deltaMean = fakeMean :- 0.0
-    val realVar = fakeM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
-    // max, min
+
+    // solve real m2n
+    val deltaMean = fakeMean
+    val realM2n = fakeM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
+
+    // remove the initial value in max and min, i.e. the Double.MaxValue or Double.MinValue.
     val max = Vectors.sparse(size, fakeMax.activeIterator.map { case (id, value) =>
       if ((value == Double.MinValue) && (realMean(id) != Double.MinValue)) (id, 0.0)
       else (id, value)
@@ -142,11 +157,11 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
     }.toSeq)
 
     // get variance
-    realVar :/= totalCnt
+    realM2n :/= totalCnt
 
     VectorRDDStatisticalSummary(
       Vectors.fromBreeze(realMean),
-      Vectors.fromBreeze(realVar),
+      Vectors.fromBreeze(realM2n),
       totalCnt.toLong,
       Vectors.fromBreeze(nnz),
       max,
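A minimal usage sketch of the summarizeStatistics API added by this change. The SparkContext setup and the sample vectors are illustrative assumptions, and VectorRDDFunctions is constructed directly here, since this diff does not show an implicit conversion from RDD[Vector].

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.rdd.VectorRDDFunctions

object SummarizeStatisticsExample {
  def main(args: Array[String]): Unit = {
    // Local SparkContext for illustration only.
    val sc = new SparkContext(new SparkConf().setAppName("summary-example").setMaster("local[2]"))

    // A small RDD[Vector] with three columns; sparse vectors are handled the same way.
    val data = sc.parallelize(Seq(
      Vectors.dense(1.0, 0.0, 3.0),
      Vectors.dense(2.0, 5.0, 0.0),
      Vectors.dense(0.0, 7.0, 6.0)))

    // Wrap the RDD and compute the column-wise summary, passing the vector size explicitly.
    val summary = new VectorRDDFunctions(data).summarizeStatistics(3)

    println(s"mean       = ${summary.mean}")
    println(s"variance   = ${summary.variance}")
    println(s"min        = ${summary.min}")
    println(s"nonZeroCnt = ${summary.nonZeroCnt}")

    sc.stop()
  }
}

The size argument must match the dimension of the input vectors, because the zero value for the aggregation (the initial means, m2n, counts, max, and min vectors) is built from it.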