@@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg.distributed
1919
2020import java .util
2121
22- import breeze .linalg .{DenseMatrix => BDM , DenseVector => BDV , svd => brzSvd }
22+ import breeze .linalg .{Vector => BV , DenseMatrix => BDM , DenseVector => BDV , svd => brzSvd }
2323import breeze .numerics .{sqrt => brzSqrt }
2424import com .github .fommil .netlib .BLAS .{getInstance => blas }
2525
@@ -29,7 +29,171 @@ import org.apache.spark.rdd.RDD
2929import org .apache .spark .Logging
3030
3131/**
32- * :: Experimental ::
32+ * Trait of the summary statistics, including mean, variance, count, max, min, and non-zero elements
33+ * count.
34+ */
trait VectorRDDStatisticalSummary {

  /** Number of vectors contained in the RDD[Vector]. */
  def count: Long

  /** Column-wise mean of the RDD[Vector]. */
  def mean: Vector

  /** Column-wise sample variance of the RDD[Vector]. */
  def variance: Vector

  /** Number of non-zero elements found in each column of the RDD[Vector]. */
  def numNonZeros: Vector

  /** Largest value of each column of the RDD[Vector]. */
  def max: Vector

  /** Smallest value of each column of the RDD[Vector]. */
  def min: Vector
}
67+
68+
/**
 * Mutable accumulator behind
 * [[org.apache.spark.mllib.linalg.distributed.VectorRDDStatisticalSummary
 * VectorRDDStatisticalSummary]]. Vectors are folded in one at a time with add(), and partial
 * results are combined with merge(). add() uses the online variance update and merge() uses the
 * parallel variance combination described in
 * [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]].
 *
 * Zero elements are skipped by add() and merge(), which lowers the cost per vector from O(n) to
 * O(nnz). As a consequence `currMean`, `currM2n`, `currMax` and `currMin` describe the non-zero
 * elements of each column only; the accessors (`mean`, `variance`, `max`, `min`) fold the
 * skipped zeros back in when they are called.
 *
 * @param currMean per-column mean of the non-zero elements seen so far
 * @param currM2n per-column sum of squared deviations of the non-zero elements from `currMean`
 * @param totalCnt number of vectors seen so far
 * @param nnz per-column number of non-zero elements seen so far
 * @param currMax per-column maximum of the non-zero elements seen so far
 * @param currMin per-column minimum of the non-zero elements seen so far
 */
private class VectorRDDStatisticsAggregator(
    val currMean: BDV[Double],
    val currM2n: BDV[Double],
    var totalCnt: Double,
    val nnz: BDV[Double],
    val currMax: BDV[Double],
    val currMin: BDV[Double])
  extends VectorRDDStatisticalSummary with Serializable {

  /**
   * Mean of every column over all `totalCnt` vectors. The mean of the non-zero elements is
   * rescaled by nnz / totalCnt to account for the skipped zeros. While `totalCnt` is 0 the
   * division by `totalCnt` is a division by zero.
   */
  override def mean: Vector = {
    val realMean = BDV.zeros[Double](currMean.length)
    var i = 0
    while (i < currMean.length) {
      realMean(i) = currMean(i) * nnz(i) / totalCnt
      i += 1
    }
    Vectors.fromBreeze(realMean)
  }

  /**
   * Sample variance (denominator totalCnt - 1) of every column over all `totalCnt` vectors.
   * When the denominator is 0 the variance is reported as 0.
   */
  override def variance: Vector = {
    val realVariance = BDV.zeros[Double](currM2n.length)

    val denominator = totalCnt - 1.0

    // Sample variance is computed; if the denominator is 0, the variance is just 0.
    if (denominator != 0.0) {
      var i = 0
      while (i < currM2n.length) {
        // Parallel combination of the non-zero elements (mean currMean, count nnz, M2 currM2n)
        // with the skipped zeros (mean 0, count totalCnt - nnz, M2 0).
        realVariance(i) =
          currM2n(i) + currMean(i) * currMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
        realVariance(i) /= denominator
        i += 1
      }
    }

    Vectors.fromBreeze(realVariance)
  }

  /** Number of vectors folded into this aggregator. */
  override def count: Long = totalCnt.toLong

  /** Number of non-zero elements seen in every column. */
  override def numNonZeros: Vector = Vectors.fromBreeze(nnz)

  /**
   * Maximum of every column. A column that contains at least one skipped zero
   * (nnz < totalCnt) cannot have a negative maximum, so it is clamped to 0.
   */
  override def max: Vector = {
    // Work on a copy so that reading the summary never alters the aggregation state.
    val realMax = currMax.copy
    var i = 0
    while (i < realMax.length) {
      if ((nnz(i) < totalCnt) && (realMax(i) < 0.0)) realMax(i) = 0.0
      i += 1
    }
    Vectors.fromBreeze(realMax)
  }

  /**
   * Minimum of every column. A column that contains at least one skipped zero
   * (nnz < totalCnt) cannot have a positive minimum, so it is clamped to 0.
   */
  override def min: Vector = {
    // Work on a copy so that reading the summary never alters the aggregation state.
    val realMin = currMin.copy
    var i = 0
    while (i < realMin.length) {
      if ((nnz(i) < totalCnt) && (realMin(i) > 0.0)) realMin(i) = 0.0
      i += 1
    }
    Vectors.fromBreeze(realMin)
  }

  /**
   * Aggregate function used for aggregating elements in a worker together.
   *
   * @param currData the vector to fold into this aggregator
   * @return this aggregator, updated in place
   */
  def add(currData: BV[Double]): this.type = {
    currData.activeIterator.foreach {
      // This case filters out the zero elements of the vector.
      case (_, 0.0) =>
      case (id, value) =>
        if (currMax(id) < value) currMax(id) = value
        if (currMin(id) > value) currMin(id) = value

        // Online update of the mean and M2 of the non-zero elements of column `id`.
        val tmpPrevMean = currMean(id)
        currMean(id) = (currMean(id) * nnz(id) + value) / (nnz(id) + 1.0)
        currM2n(id) += (value - currMean(id)) * (value - tmpPrevMean)

        nnz(id) += 1.0
    }

    totalCnt += 1.0
    this
  }

  /**
   * Combine function used for combining intermediate results together from every worker.
   *
   * @param other the aggregator to fold into this one; it is only read
   * @return this aggregator, updated in place
   */
  def merge(other: VectorRDDStatisticsAggregator): this.type = {
    require(currMean.length == other.currMean.length,
      s"Dimensions mismatch when merging: ${currMean.length} vs ${other.currMean.length}")

    totalCnt += other.totalCnt

    var i = 0
    while (i < currMean.length) {
      val thisNnz = nnz(i)
      val otherNnz = other.nnz(i)
      val totalNnz = thisNnz + otherNnz

      // Decide on the number of non-zero elements `other` contributes, not on the value of its
      // mean: non-zero elements can average exactly 0 (e.g. -1 and 1) and must still be merged.
      if (otherNnz != 0.0) {
        // The delta has to be taken from the means as they were before this merge.
        val deltaMean = currMean(i) - other.currMean(i)

        // merge m2n together
        currM2n(i) += other.currM2n(i) + deltaMean * deltaMean * thisNnz * otherNnz / totalNnz

        // merge mean together
        currMean(i) = (currMean(i) * thisNnz + other.currMean(i) * otherNnz) / totalNnz
      }

      if (currMax(i) < other.currMax(i)) currMax(i) = other.currMax(i)

      if (currMin(i) > other.currMin(i)) currMin(i) = other.currMin(i)

      nnz(i) = totalNnz
      i += 1
    }

    this
  }
}
195+
196+ /**
33197 * Represents a row-oriented distributed Matrix with no meaningful row indices.
34198 *
35199 * @param rows rows stored as an RDD[Vector]
@@ -240,6 +404,24 @@ class RowMatrix(
240404 }
241405 }
242406
/**
 * Computes the full column-wise statistics of the rows of this matrix. Every row is converted
 * to its Breeze representation and folded into a VectorRDDStatisticsAggregator sized by nCols.
 *
 * @return the aggregated summary, exposed as a VectorRDDStatisticalSummary
 */
def multiVariateSummaryStatistics(): VectorRDDStatisticalSummary = {
  // Starting point of the aggregation: nothing seen yet, with max/min set to the extreme
  // Double values so that the first non-zero element of a column replaces them.
  val emptySummary = new VectorRDDStatisticsAggregator(
    currMean = BDV.zeros[Double](nCols),
    currM2n = BDV.zeros[Double](nCols),
    totalCnt = 0.0,
    nnz = BDV.zeros[Double](nCols),
    currMax = BDV.fill(nCols)(Double.MinValue),
    currMin = BDV.fill(nCols)(Double.MaxValue))

  val breezeRows = rows.map(row => row.toBreeze)

  breezeRows.aggregate[VectorRDDStatisticsAggregator](emptySummary)(
    (partial, row) => partial.add(row),
    (left, right) => left.merge(right)
  )
}
424+
243425 /**
244426 * Multiply this matrix by a local matrix on the right.
245427 *
0 commit comments