@@ -63,8 +63,8 @@ private class ColumnStatisticsAggregator(private val n: Int)
6363
6464 val denominator = totalCnt - 1.0
6565
66- // Sample variance is computed, if the denominator is 0, the variance is just 0.
67- if (denominator != 0.0 ) {
66+ // Sample variance is computed, if the denominator is less than 0, the variance is just 0.
67+ if (denominator > 0.0 ) {
6868 val deltaMean = currMean
6969 var i = 0
7070 while (i < currM2n.size) {
@@ -107,8 +107,12 @@ private class ColumnStatisticsAggregator(private val n: Int)
107107 currData.activeIterator.foreach {
108108 case (_, 0.0 ) => // Skip explicit zero elements.
109109 case (i, value) =>
110- if (currMax(i) < value) currMax(i) = value
111- if (currMin(i) > value) currMin(i) = value
110+ if (currMax(i) < value) {
111+ currMax(i) = value
112+ }
113+ if (currMin(i) > value) {
114+ currMin(i) = value
115+ }
112116
113117 val tmpPrevMean = currMean(i)
114118 currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0 )
@@ -125,11 +129,9 @@ private class ColumnStatisticsAggregator(private val n: Int)
125129 * Merges another aggregator.
126130 */
127131 def merge (other : ColumnStatisticsAggregator ): this .type = {
128-
129132 require(n == other.n, s " Dimensions mismatch. Expecting $n but got ${other.n}. " )
130133
131134 totalCnt += other.totalCnt
132-
133135 val deltaMean = currMean - other.currMean
134136
135137 var i = 0
@@ -139,22 +141,21 @@ private class ColumnStatisticsAggregator(private val n: Int)
139141 currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
140142 (nnz(i) + other.nnz(i))
141143 }
142-
143144 // merge m2n together
144145 if (nnz(i) + other.nnz(i) != 0.0 ) {
145146 currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
146147 (nnz(i) + other.nnz(i))
147148 }
148-
149- if (currMax(i) < other.currMax(i)) currMax(i) = other.currMax(i)
150-
151- if (currMin(i) > other.currMin(i)) currMin(i) = other.currMin(i)
152-
149+ if (currMax(i) < other.currMax(i)) {
150+ currMax(i) = other.currMax(i)
151+ }
152+ if (currMin(i) > other.currMin(i)) {
153+ currMin(i) = other.currMin(i)
154+ }
153155 i += 1
154156 }
155157
156158 nnz += other.nnz
157-
158159 this
159160 }
160161}
@@ -414,17 +415,6 @@ class RowMatrix(
414415 mat
415416 }
416417
417- /** Updates or verifies the number of columns. */
418- private def updateNumCols (n : Int ) {
419- if (nCols <= 0 ) {
420- nCols == n
421- } else {
422- require(nCols == n,
423- s " The number of columns $n is different from " +
424- s " what specified or previously computed: ${nCols}. " )
425- }
426- }
427-
428418 /** Updates or verfires the number of rows. */
429419 private def updateNumRows (m : Long ) {
430420 if (nRows <= 0 ) {
0 commit comments