Skip to content

Commit 4e4fbd1

Browse files
committed
separate seqop and combop out as independent functions
1 parent a6d5a2e commit 4e4fbd1

File tree

1 file changed

+62
-40
lines changed

1 file changed

+62
-40
lines changed

mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala

Lines changed: 62 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@ case class VectorRDDStatisticalSummary(
3030
min: Vector,
3131
nonZeroCnt: Vector) extends Serializable
3232

33+
/**
 * Intermediate per-column statistics carried through the aggregation in `summarizeStatistics`.
 *
 * @param fakeMean intermediate per-column mean; `summarizeStatistics` rescales it by
 *                 `nnz / totalCnt` to obtain the real mean
 * @param fakeM2n  intermediate per-column accumulator of squared deviations;
 *                 `summarizeStatistics` corrects it and divides by `totalCnt` to obtain the
 *                 real variance
 * @param totalCnt number of vectors folded in so far (`seqOp` adds 1.0 per vector)
 * @param nnz      per-column count of entries visited through `activeIterator`
 * @param max      per-column running maximum of the visited entries
 * @param min      per-column running minimum of the visited entries
 */
private case class VectorRDDStatisticalRing(
    fakeMean: BV[Double],
    fakeM2n: BV[Double],
    totalCnt: Double,
    nnz: BV[Double],
    max: BV[Double],
    min: BV[Double])
40+
3341
/**
3442
* Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector]] through an
3543
* implicit conversion. Import `org.apache.spark.MLContext._` at the top of your program to use
@@ -49,57 +57,71 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
4957
* }}},
5058
* with the size of Vector as input parameter.
5159
*/
60+
61+
/**
 * Folds one vector into the running per-column statistics (the sequential step of an
 * aggregate).
 *
 * Only the entries exposed by `currData.activeIterator` are visited. For each visited column
 * the running max, min, mean, squared-deviation accumulator and visit count are updated.
 * The vectors held by `aggregator` are mutated in place; the returned ring shares them and
 * carries the vector count incremented by one.
 *
 * @param aggregator statistics accumulated so far (its vectors are updated in place)
 * @param currData   the vector to fold in
 * @return a ring over the same (updated) vectors with `totalCnt + 1.0`
 */
private def seqOp(
    aggregator: VectorRDDStatisticalRing,
    currData: BV[Double]): VectorRDDStatisticalRing = {
  val VectorRDDStatisticalRing(prevMean, prevM2n, cnt, nnzVec, maxVec, minVec) = aggregator

  currData.activeIterator.foreach { case (id, value) =>
    if (maxVec(id) < value) maxVec(id) = value
    if (minVec(id) > value) minVec(id) = value

    // Weight the running mean by the number of values already seen in THIS column, not by
    // the number of vectors seen: a column may be inactive in some vectors, and both combOp
    // and the final `fakeMean :* nnz :/ totalCnt` rescaling treat the mean as nnz-weighted.
    val seenInColumn = nnzVec(id)
    val tmpPrevMean = prevMean(id)
    prevMean(id) = (tmpPrevMean * seenInColumn + value) / (seenInColumn + 1.0)
    // Incremental squared-deviation update: uses the mean before and after this value.
    prevM2n(id) += (value - prevMean(id)) * (value - tmpPrevMean)

    nnzVec(id) = seenInColumn + 1.0
  }

  VectorRDDStatisticalRing(prevMean, prevM2n, cnt + 1.0, nnzVec, maxVec, minVec)
}
84+
85+
/**
 * Merges two partial statistics into one (the combine step of an aggregate).
 *
 * Means and squared-deviation accumulators are merged per column, weighted by each side's
 * per-column visit count. Columns that neither side has visited keep a mean and accumulator
 * of 0.0 instead of the NaN that an unguarded 0/0 division would produce (and that would
 * then propagate through every later merge). `max1` and `min1` of `statistics1` are updated
 * in place and reused in the result.
 *
 * @param statistics1 left partial result (its `max` and `min` vectors are mutated)
 * @param statistics2 right partial result
 * @return the merged statistics
 */
private def combOp(
    statistics1: VectorRDDStatisticalRing,
    statistics2: VectorRDDStatisticalRing): VectorRDDStatisticalRing = {
  (statistics1, statistics2) match {
    case (VectorRDDStatisticalRing(mean1, m2n1, cnt1, nnz1, max1, min1),
          VectorRDDStatisticalRing(mean2, m2n2, cnt2, nnz2, max2, min2)) =>
      val totalCnt = cnt1 + cnt2
      val totalNnz = nnz1 + nnz2
      val totalMean = BV.zeros[Double](mean1.length)
      val totalM2n = BV.zeros[Double](mean1.length)

      (0 until mean1.length).foreach { id =>
        val n = totalNnz(id)
        // Guard the division: with no visited entries on either side the merged mean and
        // accumulator stay at 0.0 rather than becoming NaN.
        if (n > 0.0) {
          val deltaMean = mean2(id) - mean1(id)
          totalMean(id) = (mean1(id) * nnz1(id) + mean2(id) * nnz2(id)) / n
          totalM2n(id) =
            m2n1(id) + m2n2(id) + (deltaMean * deltaMean) * (nnz1(id) * nnz2(id)) / n
        }
      }

      max2.activeIterator.foreach { case (id, value) =>
        if (max1(id) < value) max1(id) = value
      }
      min2.activeIterator.foreach { case (id, value) =>
        if (min1(id) > value) min1(id) = value
      }

      VectorRDDStatisticalRing(totalMean, totalM2n, totalCnt, totalNnz, max1, min1)
  }
}
104+
52105
def summarizeStatistics(size: Int): VectorRDDStatisticalSummary = {
53-
val (fakeMean, fakeM2n, totalCnt, nnz, max, min) = self.map(_.toBreeze).aggregate((
106+
val zeroValue = VectorRDDStatisticalRing(
54107
BV.zeros[Double](size),
55108
BV.zeros[Double](size),
56109
0.0,
57110
BV.zeros[Double](size),
58-
BV.fill(size){Double.MinValue},
59-
BV.fill(size){Double.MaxValue}))(
60-
seqOp = (c, v) => (c, v) match {
61-
case ((prevMean, prevM2n, cnt, nnzVec, maxVec, minVec), currData) =>
62-
currData.activeIterator.map{ case (id, value) =>
63-
val tmpPrevMean = prevMean(id)
64-
prevMean(id) = (prevMean(id) * cnt + value) / (cnt + 1.0)
65-
if (maxVec(id) < value) maxVec(id) = value
66-
if (minVec(id) > value) minVec(id) = value
67-
nnzVec(id) += 1.0
68-
prevM2n(id) += (value - prevMean(id)) * (value - tmpPrevMean)
69-
}
70-
71-
(prevMean,
72-
prevM2n,
73-
cnt + 1.0,
74-
nnzVec,
75-
maxVec,
76-
minVec)
77-
},
78-
combOp = (c, v) => (c, v) match {
79-
case (
80-
(mean1, m2n1, cnt1, nnz1, max1, min1),
81-
(mean2, m2n2, cnt2, nnz2, max2, min2)) =>
82-
val totalCnt = cnt1 + cnt2
83-
val deltaMean = mean2 - mean1
84-
val totalMean = ((mean1 :* nnz1) + (mean2 :* nnz2)) :/ (nnz1 + nnz2)
85-
val totalM2n = m2n1 + m2n2 + ((deltaMean :* deltaMean) :* (nnz1 :* nnz2) :/ (nnz1 + nnz2))
86-
max2.activeIterator.foreach { case (id, value) =>
87-
if (max1(id) < value) max1(id) = value
88-
}
89-
min2.activeIterator.foreach { case (id, value) =>
90-
if (min1(id) > value) min1(id) = value
91-
}
92-
(totalMean, totalM2n, totalCnt, nnz1 + nnz2, max1, min1)
93-
}
94-
)
111+
BV.fill(size)(Double.MinValue),
112+
BV.fill(size)(Double.MaxValue))
113+
114+
val breezeVectors = self.collect().map(_.toBreeze)
115+
val VectorRDDStatisticalRing(fakeMean, fakeM2n, totalCnt, nnz, max, min) = breezeVectors.aggregate(zeroValue)(seqOp, combOp)
95116

96117
// solve real mean
97118
val realMean = fakeMean :* nnz :/ totalCnt
98119
// solve real variance
99120
val deltaMean = fakeMean :- 0.0
100121
val realVar = fakeM2n - ((deltaMean :* deltaMean) :* (nnz :* (nnz :- totalCnt)) :/ totalCnt)
101-
max :+= 0.0
102-
min :+= 0.0
122+
// max, min process, in case of a column is all zero.
123+
// max :+= 0.0
124+
// min :+= 0.0
103125

104126
realVar :/= totalCnt
105127

0 commit comments

Comments
 (0)