Skip to content

Commit 84324fb

Browse files
committed
[SPARK-4355][MLLIB] fix OnlineSummarizer.merge when other.mean is zero
See inline comment about the bug. I also did some code clean-up. dbtsai I moved `update` to a private method of `MultivariateOnlineSummarizer`. I don't think it will cause performance regression, but it would be great if you have some time to test. Author: Xiangrui Meng <[email protected]> Closes #3220 from mengxr/SPARK-4355 and squashes the following commits: 5ef601f [Xiangrui Meng] fix OnlineSummarizer.merge when other.mean is zero and some code clean-up
1 parent faeb41d commit 84324fb

File tree

2 files changed

+51
-45
lines changed

2 files changed

+51
-45
lines changed

mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala

Lines changed: 40 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,29 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
4949
private var currMax: BDV[Double] = _
5050
private var currMin: BDV[Double] = _
5151

52+
/**
53+
* Adds input value to position i.
54+
*/
55+
private[this] def add(i: Int, value: Double) = {
56+
if (value != 0.0) {
57+
if (currMax(i) < value) {
58+
currMax(i) = value
59+
}
60+
if (currMin(i) > value) {
61+
currMin(i) = value
62+
}
63+
64+
val prevMean = currMean(i)
65+
val diff = value - prevMean
66+
currMean(i) = prevMean + diff / (nnz(i) + 1.0)
67+
currM2n(i) += (value - currMean(i)) * diff
68+
currM2(i) += value * value
69+
currL1(i) += math.abs(value)
70+
71+
nnz(i) += 1.0
72+
}
73+
}
74+
5275
/**
5376
* Add a new sample to this summarizer, and update the statistical summary.
5477
*
@@ -72,37 +95,18 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
7295
require(n == sample.size, s"Dimensions mismatch when adding new sample." +
7396
s" Expecting $n but got ${sample.size}.")
7497

75-
@inline def update(i: Int, value: Double) = {
76-
if (value != 0.0) {
77-
if (currMax(i) < value) {
78-
currMax(i) = value
79-
}
80-
if (currMin(i) > value) {
81-
currMin(i) = value
82-
}
83-
84-
val tmpPrevMean = currMean(i)
85-
currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
86-
currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)
87-
currM2(i) += value * value
88-
currL1(i) += math.abs(value)
89-
90-
nnz(i) += 1.0
91-
}
92-
}
93-
9498
sample match {
9599
case dv: DenseVector => {
96100
var j = 0
97101
while (j < dv.size) {
98-
update(j, dv.values(j))
102+
add(j, dv.values(j))
99103
j += 1
100104
}
101105
}
102106
case sv: SparseVector =>
103107
var j = 0
104108
while (j < sv.indices.size) {
105-
update(sv.indices(j), sv.values(j))
109+
add(sv.indices(j), sv.values(j))
106110
j += 1
107111
}
108112
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
@@ -124,37 +128,28 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
124128
require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " +
125129
s"Expecting $n but got ${other.n}.")
126130
totalCnt += other.totalCnt
127-
val deltaMean: BDV[Double] = currMean - other.currMean
128131
var i = 0
129132
while (i < n) {
130-
// merge mean together
131-
if (other.currMean(i) != 0.0) {
132-
currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
133-
(nnz(i) + other.nnz(i))
134-
}
135-
// merge m2n together
136-
if (nnz(i) + other.nnz(i) != 0.0) {
137-
currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
138-
(nnz(i) + other.nnz(i))
139-
}
140-
// merge m2 together
141-
if (nnz(i) + other.nnz(i) != 0.0) {
133+
val thisNnz = nnz(i)
134+
val otherNnz = other.nnz(i)
135+
val totalNnz = thisNnz + otherNnz
136+
if (totalNnz != 0.0) {
137+
val deltaMean = other.currMean(i) - currMean(i)
138+
// merge mean together
139+
currMean(i) += deltaMean * otherNnz / totalNnz
140+
// merge m2n together
141+
currM2n(i) += other.currM2n(i) + deltaMean * deltaMean * thisNnz * otherNnz / totalNnz
142+
// merge m2 together
142143
currM2(i) += other.currM2(i)
143-
}
144-
// merge l1 together
145-
if (nnz(i) + other.nnz(i) != 0.0) {
144+
// merge l1 together
146145
currL1(i) += other.currL1(i)
146+
// merge max and min
147+
currMax(i) = math.max(currMax(i), other.currMax(i))
148+
currMin(i) = math.min(currMin(i), other.currMin(i))
147149
}
148-
149-
if (currMax(i) < other.currMax(i)) {
150-
currMax(i) = other.currMax(i)
151-
}
152-
if (currMin(i) > other.currMin(i)) {
153-
currMin(i) = other.currMin(i)
154-
}
150+
nnz(i) = totalNnz
155151
i += 1
156152
}
157-
nnz += other.nnz
158153
} else if (totalCnt == 0 && other.totalCnt != 0) {
159154
this.n = other.n
160155
this.currMean = other.currMean.copy

mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,4 +208,15 @@ class MultivariateOnlineSummarizerSuite extends FunSuite {
208208

209209
assert(summarizer2.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch")
210210
}
211+
212+
test("merging summarizer when one side has zero mean (SPARK-4355)") {
213+
val s0 = new MultivariateOnlineSummarizer()
214+
.add(Vectors.dense(2.0))
215+
.add(Vectors.dense(2.0))
216+
val s1 = new MultivariateOnlineSummarizer()
217+
.add(Vectors.dense(1.0))
218+
.add(Vectors.dense(-1.0))
219+
s0.merge(s1)
220+
assert(s0.mean(0) ~== 1.0 absTol 1e-14)
221+
}
211222
}

0 commit comments

Comments
 (0)