Skip to content

Commit a6d5a2e

Browse files
committed
rewrite for only computing non-zero elements
1 parent 3980287 commit a6d5a2e

File tree

1 file changed

+26
-27
lines changed

1 file changed

+26
-27
lines changed

mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala

Lines changed: 26 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
5050
* with the size of Vector as input parameter.
5151
*/
5252
def summarizeStatistics(size: Int): VectorRDDStatisticalSummary = {
53-
val results = self.map(_.toBreeze).aggregate((
53+
val (fakeMean, fakeM2n, totalCnt, nnz, max, min) = self.map(_.toBreeze).aggregate((
5454
BV.zeros[Double](size),
5555
BV.zeros[Double](size),
5656
0.0,
@@ -59,19 +59,16 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
5959
BV.fill(size){Double.MaxValue}))(
6060
seqOp = (c, v) => (c, v) match {
6161
case ((prevMean, prevM2n, cnt, nnzVec, maxVec, minVec), currData) =>
62-
val currMean = prevMean :* (cnt / (cnt + 1.0))
63-
axpy(1.0/(cnt+1.0), currData, currMean)
64-
axpy(-1.0, currData, prevMean)
65-
prevMean :*= (currMean - currData)
66-
axpy(1.0, prevMean, prevM2n)
67-
axpy(1.0,
68-
Vectors.sparse(size, currData.activeKeysIterator.toSeq.map(x => (x, 1.0))).toBreeze,
69-
nnzVec)
70-
currData.activeIterator.foreach { case (id, value) =>
62+
currData.activeIterator.map{ case (id, value) =>
63+
val tmpPrevMean = prevMean(id)
64+
prevMean(id) = (prevMean(id) * cnt + value) / (cnt + 1.0)
7165
if (maxVec(id) < value) maxVec(id) = value
7266
if (minVec(id) > value) minVec(id) = value
67+
nnzVec(id) += 1.0
68+
prevM2n(id) += (value - prevMean(id)) * (value - tmpPrevMean)
7369
}
74-
(currMean,
70+
71+
(prevMean,
7572
prevM2n,
7673
cnt + 1.0,
7774
nnzVec,
@@ -84,32 +81,34 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
8481
(mean2, m2n2, cnt2, nnz2, max2, min2)) =>
8582
val totalCnt = cnt1 + cnt2
8683
val deltaMean = mean2 - mean1
87-
mean1 :*= (cnt1 / totalCnt)
88-
axpy(cnt2/totalCnt, mean2, mean1)
89-
val totalMean = mean1
90-
deltaMean :*= deltaMean
91-
axpy(cnt1*cnt2/totalCnt, deltaMean, m2n1)
92-
axpy(1.0, m2n2, m2n1)
93-
val totalM2n = m2n1
84+
val totalMean = ((mean1 :* nnz1) + (mean2 :* nnz2)) :/ (nnz1 + nnz2)
85+
val totalM2n = m2n1 + m2n2 + ((deltaMean :* deltaMean) :* (nnz1 :* nnz2) :/ (nnz1 + nnz2))
9486
max2.activeIterator.foreach { case (id, value) =>
9587
if (max1(id) < value) max1(id) = value
9688
}
9789
min2.activeIterator.foreach { case (id, value) =>
9890
if (min1(id) > value) min1(id) = value
9991
}
100-
axpy(1.0, nnz2, nnz1)
101-
(totalMean, totalM2n, totalCnt, nnz1, max1, min1)
92+
(totalMean, totalM2n, totalCnt, nnz1 + nnz2, max1, min1)
10293
}
10394
)
10495

105-
results._2 :/= results._3
96+
// solve real mean
97+
val realMean = fakeMean :* nnz :/ totalCnt
98+
// solve real variance
99+
val deltaMean = fakeMean :- 0.0
100+
val realVar = fakeM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
101+
max :+= 0.0
102+
min :+= 0.0
103+
104+
realVar :/= totalCnt
106105

107106
VectorRDDStatisticalSummary(
108-
Vectors.fromBreeze(results._1),
109-
Vectors.fromBreeze(results._2),
110-
results._3.toLong,
111-
Vectors.fromBreeze(results._4),
112-
Vectors.fromBreeze(results._5),
113-
Vectors.fromBreeze(results._6))
107+
Vectors.fromBreeze(realMean),
108+
Vectors.fromBreeze(realVar),
109+
totalCnt.toLong,
110+
Vectors.fromBreeze(nnz),
111+
Vectors.fromBreeze(max),
112+
Vectors.fromBreeze(min))
114113
}
115114
}

0 commit comments

Comments
 (0)