Skip to content

Commit 4eaf28a

Browse files
committed
merge VectorRDDStatistics into RowMatrix
1 parent 48ee053 commit 4eaf28a

File tree

5 files changed

+229
-306
lines changed

5 files changed

+229
-306
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala

Lines changed: 184 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg.distributed
1919

2020
import java.util
2121

22-
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, svd => brzSvd}
22+
import breeze.linalg.{Vector => BV, DenseMatrix => BDM, DenseVector => BDV, svd => brzSvd}
2323
import breeze.numerics.{sqrt => brzSqrt}
2424
import com.github.fommil.netlib.BLAS.{getInstance => blas}
2525

@@ -29,7 +29,171 @@ import org.apache.spark.rdd.RDD
2929
import org.apache.spark.Logging
3030

3131
/**
32-
* :: Experimental ::
32+
* Trait of the summary statistics, including mean, variance, count, max, min, and non-zero elements
33+
* count.
34+
*/
35+
trait VectorRDDStatisticalSummary {

  /**
   * Column-wise mean of the vectors in the RDD[Vector].
   */
  def mean: Vector

  /**
   * Column-wise sample variance of the vectors in the RDD[Vector].
   */
  def variance: Vector

  /**
   * Number of vectors in the RDD[Vector].
   */
  def count: Long

  /**
   * For every column, the number of entries that are non-zero.
   */
  def numNonZeros: Vector

  /**
   * Column-wise maximum of the vectors in the RDD[Vector].
   */
  def max: Vector

  /**
   * Column-wise minimum of the vectors in the RDD[Vector].
   */
  def min: Vector
}
67+
68+
69+
/**
 * Mutable accumulator that implements
 * [[org.apache.spark.mllib.linalg.distributed.VectorRDDStatisticalSummary]] and is built up with
 * `add()` (one vector at a time) and `merge()` (two partial results).
 *
 * `add()` uses the online variance update and `merge()` the parallel variance combination
 * described at [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance]]. Both skip
 * zero entries, which brings the per-vector cost down from O(n) to O(nnz). As a consequence
 * `currMean` and `currM2n` describe only the non-zero entries of each column; the `mean` and
 * `variance` accessors fold the skipped zeros back in by means of `totalCnt - nnz(i)`.
 *
 * @param currMean per-column mean of the non-zero entries seen so far
 * @param currM2n per-column sum of squared deviations of the non-zero entries from `currMean`
 * @param totalCnt number of vectors added so far
 * @param nnz per-column number of non-zero entries seen so far
 * @param currMax per-column maximum of the non-zero entries seen so far
 * @param currMin per-column minimum of the non-zero entries seen so far
 */
private class VectorRDDStatisticsAggregator(
    val currMean: BDV[Double],
    val currM2n: BDV[Double],
    var totalCnt: Double,
    val nnz: BDV[Double],
    val currMax: BDV[Double],
    val currMin: BDV[Double])
  extends VectorRDDStatisticalSummary with Serializable {

  /**
   * Column-wise mean over all vectors, zeros included.
   */
  override def mean: Vector = {
    val realMean = BDV.zeros[Double](currMean.length)
    var i = 0
    while (i < currMean.length) {
      // currMean(i) * nnz(i) is the column sum; the skipped zeros contribute nothing to it.
      realMean(i) = currMean(i) * nnz(i) / totalCnt
      i += 1
    }
    Vectors.fromBreeze(realMean)
  }

  /**
   * Column-wise sample variance over all vectors, zeros included. The result is all zeros when
   * fewer than two vectors have been added.
   */
  override def variance: Vector = {
    val realVariance = BDV.zeros[Double](currM2n.length)

    val denominator = totalCnt - 1.0

    // Sample variance needs at least two vectors. Testing `> 0.0` rather than `!= 0.0` also keeps
    // an empty aggregator (denominator -1.0) away from the 0.0 / 0.0 in the formula below.
    if (denominator > 0.0) {
      var i = 0
      while (i < currM2n.length) {
        // Parallel variance combination of the non-zero entries (mean currMean(i), size nnz(i))
        // with the skipped zeros (mean 0, M2n 0, size totalCnt - nnz(i)).
        realVariance(i) =
          currM2n(i) + currMean(i) * currMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
        realVariance(i) /= denominator
        i += 1
      }
    }

    Vectors.fromBreeze(realVariance)
  }

  /**
   * Number of vectors added so far.
   */
  override def count: Long = totalCnt.toLong

  /**
   * Per-column number of non-zero entries.
   */
  override def numNonZeros: Vector = Vectors.fromBreeze(nnz)

  /**
   * Column-wise maximum over all vectors, zeros included.
   */
  override def max: Vector = {
    // Work on a copy so that calling this accessor does not modify the aggregator.
    val realMax = currMax.copy
    var i = 0
    while (i < nnz.length) {
      // nnz(i) < totalCnt means at least one zero was skipped in this column, so the maximum
      // cannot be negative.
      if ((nnz(i) < totalCnt) && (realMax(i) < 0.0)) realMax(i) = 0.0
      i += 1
    }
    Vectors.fromBreeze(realMax)
  }

  /**
   * Column-wise minimum over all vectors, zeros included.
   */
  override def min: Vector = {
    // Work on a copy so that calling this accessor does not modify the aggregator.
    val realMin = currMin.copy
    var i = 0
    while (i < nnz.length) {
      // nnz(i) < totalCnt means at least one zero was skipped in this column, so the minimum
      // cannot be positive.
      if ((nnz(i) < totalCnt) && (realMin(i) > 0.0)) realMin(i) = 0.0
      i += 1
    }
    Vectors.fromBreeze(realMin)
  }

  /**
   * Aggregate function used for aggregating elements in a worker together.
   *
   * @param currData the vector to fold into this aggregator
   * @return this aggregator, updated in place
   */
  def add(currData: BV[Double]): this.type = {
    currData.activeIterator.foreach {
      // Zero entries of the vector are skipped; they are accounted for through totalCnt.
      case (_, 0.0) =>
      case (id, value) =>
        if (currMax(id) < value) currMax(id) = value
        if (currMin(id) > value) currMin(id) = value

        // Online update of the mean and of the sum of squared deviations.
        val tmpPrevMean = currMean(id)
        currMean(id) = (currMean(id) * nnz(id) + value) / (nnz(id) + 1.0)
        currM2n(id) += (value - currMean(id)) * (value - tmpPrevMean)

        nnz(id) += 1.0
    }

    totalCnt += 1.0
    this
  }

  /**
   * Combine function used for combining intermediate results together from every worker.
   *
   * @param other the aggregator whose statistics are folded into this one
   * @return this aggregator, updated in place
   */
  def merge(other: VectorRDDStatisticsAggregator): this.type = {
    require(currMean.length == other.currMean.length,
      s"Dimension mismatch when merging: ${currMean.length} != ${other.currMean.length}")

    totalCnt += other.totalCnt

    var i = 0
    while (i < other.currMean.length) {
      val thisNnz = nnz(i)
      val otherNnz = other.nnz(i)
      val totalNnz = thisNnz + otherNnz
      // Taken before currMean(i) is overwritten below.
      val deltaMean = currMean(i) - other.currMean(i)

      // Merge the means. The test is on the other side's non-zero count, not on its mean:
      // non-zero values such as +1 and -1 average to exactly 0.0 and must still be merged.
      if (otherNnz != 0.0) {
        currMean(i) = (currMean(i) * thisNnz + other.currMean(i) * otherNnz) / totalNnz
      }

      // Merge the sums of squared deviations with the parallel variance formula.
      if (totalNnz != 0.0) {
        currM2n(i) += other.currM2n(i) + deltaMean * deltaMean * thisNnz * otherNnz / totalNnz
      }

      if (currMax(i) < other.currMax(i)) currMax(i) = other.currMax(i)

      if (currMin(i) > other.currMin(i)) currMin(i) = other.currMin(i)

      i += 1
    }

    nnz += other.nnz
    this
  }
}
195+
196+
/**
33197
* Represents a row-oriented distributed Matrix with no meaningful row indices.
34198
*
35199
* @param rows rows stored as an RDD[Vector]
@@ -240,6 +404,24 @@ class RowMatrix(
240404
}
241405
}
242406

407+
/**
 * Computes column-wise summary statistics over the rows of this matrix.
 *
 * Every row is converted to its Breeze form and folded into a `VectorRDDStatisticsAggregator`
 * with `add`; the partial aggregators are then combined with `merge`.
 *
 * @return the aggregated result, exposed as a `VectorRDDStatisticalSummary`
 */
def multiVariateSummaryStatistics(): VectorRDDStatisticalSummary = {
  // Empty accumulator with one slot per column. The running maximum starts at the smallest
  // Double and the running minimum at the largest one.
  val initial = new VectorRDDStatisticsAggregator(
    currMean = BDV.zeros[Double](nCols),
    currM2n = BDV.zeros[Double](nCols),
    totalCnt = 0.0,
    nnz = BDV.zeros[Double](nCols),
    currMax = BDV.fill(nCols)(Double.MinValue),
    currMin = BDV.fill(nCols)(Double.MaxValue))

  val breezeRows = rows.map(row => row.toBreeze)

  breezeRows.aggregate[VectorRDDStatisticsAggregator](initial)(
    (acc, row) => acc.add(row),
    (left, right) => left.merge(right))
}
424+
243425
/**
244426
* Multiply this matrix by a local matrix on the right.
245427
*

0 commit comments

Comments
 (0)