@@ -49,6 +49,29 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
4949 private var currMax : BDV [Double ] = _
5050 private var currMin : BDV [Double ] = _
5151
52+ /**
53+ * Adds input value to position i.
54+ */
55+ private [this ] def add (i : Int , value : Double ) = {
56+ if (value != 0.0 ) {
57+ if (currMax(i) < value) {
58+ currMax(i) = value
59+ }
60+ if (currMin(i) > value) {
61+ currMin(i) = value
62+ }
63+
64+ val prevMean = currMean(i)
65+ val diff = value - prevMean
66+ currMean(i) = prevMean + diff / (nnz(i) + 1.0 )
67+ currM2n(i) += (value - currMean(i)) * diff
68+ currM2(i) += value * value
69+ currL1(i) += math.abs(value)
70+
71+ nnz(i) += 1.0
72+ }
73+ }
74+
5275 /**
5376 * Add a new sample to this summarizer, and update the statistical summary.
5477 *
@@ -72,37 +95,18 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
7295 require(n == sample.size, s " Dimensions mismatch when adding new sample. " +
7396 s " Expecting $n but got ${sample.size}. " )
7497
75- @ inline def update (i : Int , value : Double ) = {
76- if (value != 0.0 ) {
77- if (currMax(i) < value) {
78- currMax(i) = value
79- }
80- if (currMin(i) > value) {
81- currMin(i) = value
82- }
83-
84- val tmpPrevMean = currMean(i)
85- currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0 )
86- currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)
87- currM2(i) += value * value
88- currL1(i) += math.abs(value)
89-
90- nnz(i) += 1.0
91- }
92- }
93-
9498 sample match {
9599 case dv : DenseVector => {
96100 var j = 0
97101 while (j < dv.size) {
98- update (j, dv.values(j))
102+ add (j, dv.values(j))
99103 j += 1
100104 }
101105 }
102106 case sv : SparseVector =>
103107 var j = 0
104108 while (j < sv.indices.size) {
105- update (sv.indices(j), sv.values(j))
109+ add (sv.indices(j), sv.values(j))
106110 j += 1
107111 }
108112 case v => throw new IllegalArgumentException (" Do not support vector type " + v.getClass)
@@ -124,37 +128,28 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
124128 require(n == other.n, s " Dimensions mismatch when merging with another summarizer. " +
125129 s " Expecting $n but got ${other.n}. " )
126130 totalCnt += other.totalCnt
127- val deltaMean : BDV [Double ] = currMean - other.currMean
128131 var i = 0
129132 while (i < n) {
130- // merge mean together
131- if (other.currMean(i) != 0.0 ) {
132- currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
133- (nnz(i) + other.nnz(i))
134- }
135- // merge m2n together
136- if (nnz(i) + other.nnz(i) != 0.0 ) {
137- currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
138- (nnz(i) + other.nnz(i))
139- }
140- // merge m2 together
141- if (nnz(i) + other.nnz(i) != 0.0 ) {
133+ val thisNnz = nnz(i)
134+ val otherNnz = other.nnz(i)
135+ val totalNnz = thisNnz + otherNnz
136+ if (totalNnz != 0.0 ) {
137+ val deltaMean = other.currMean(i) - currMean(i)
138+ // merge mean together
139+ currMean(i) += deltaMean * otherNnz / totalNnz
140+ // merge m2n together
141+ currM2n(i) += other.currM2n(i) + deltaMean * deltaMean * thisNnz * otherNnz / totalNnz
142+ // merge m2 together
142143 currM2(i) += other.currM2(i)
143- }
144- // merge l1 together
145- if (nnz(i) + other.nnz(i) != 0.0 ) {
144+ // merge l1 together
146145 currL1(i) += other.currL1(i)
146+ // merge max and min
147+ currMax(i) = math.max(currMax(i), other.currMax(i))
148+ currMin(i) = math.min(currMin(i), other.currMin(i))
147149 }
148-
149- if (currMax(i) < other.currMax(i)) {
150- currMax(i) = other.currMax(i)
151- }
152- if (currMin(i) > other.currMin(i)) {
153- currMin(i) = other.currMin(i)
154- }
150+ nnz(i) = totalNnz
155151 i += 1
156152 }
157- nnz += other.nnz
158153 } else if (totalCnt == 0 && other.totalCnt != 0 ) {
159154 this .n = other.n
160155 this .currMean = other.currMean.copy
0 commit comments