 */
package org.apache.spark.mllib.rdd

-import breeze.linalg.{Vector => BV}
+import breeze.linalg.{Vector => BV, axpy}

import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD
@@ -92,8 +92,14 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
      VectorRDDStatisticalRing(mean2, m2n2, cnt2, nnz2, max2, min2)) =>
        val totalCnt = cnt1 + cnt2
        val deltaMean = mean2 - mean1
-       val totalMean = ((mean1 :* nnz1) + (mean2 :* nnz2)) :/ (nnz1 + nnz2)
-       val totalM2n = m2n1 + m2n2 + ((deltaMean :* deltaMean) :* (nnz1 :* nnz2) :/ (nnz1 + nnz2))
+       mean2.activeIterator.foreach {
+         case (id, 0.0) =>
+         case (id, value) => mean1(id) = (mean1(id) * nnz1(id) + mean2(id) * nnz2(id)) / (nnz1(id) + nnz2(id))
+       }
+       m2n2.activeIterator.foreach {
+         case (id, 0.0) =>
+         case (id, value) => m2n1(id) += value + deltaMean(id) * deltaMean(id) * nnz1(id) * nnz2(id) / (nnz1(id) + nnz2(id))
+       }
        max2.activeIterator.foreach {
          case (id, value) =>
            if (max1(id) < value) max1(id) = value
@@ -102,7 +108,8 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
          case (id, value) =>
            if (min1(id) > value) min1(id) = value
        }
-       VectorRDDStatisticalRing(totalMean, totalM2n, totalCnt, nnz1 + nnz2, max1, min1)
+       axpy(1.0, nnz2, nnz1)
+       VectorRDDStatisticalRing(mean1, m2n1, totalCnt, nnz1, max1, min1)
      }
  }

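The per-coordinate updates above are the standard pairwise merge of a running mean and sum of squared deviations (M2n), weighted by the per-coordinate nonzero counts, applied in place through the sparse activeIterator and finished with axpy to fold nnz2 into nnz1. A minimal standalone sketch of that merge rule follows; the names MergeSketch and mergeCoordinate are illustrative only and not part of this patch.

// Sketch only: merge the (mean, M2n, nonzero count) summaries of two chunks for
// a single coordinate, mirroring the weighted update applied via activeIterator.
object MergeSketch {
  def mergeCoordinate(
      mean1: Double, m2n1: Double, nnz1: Double,
      mean2: Double, m2n2: Double, nnz2: Double): (Double, Double, Double) = {
    val delta = mean2 - mean1
    val totalNnz = nnz1 + nnz2
    val mergedMean = (mean1 * nnz1 + mean2 * nnz2) / totalNnz
    val mergedM2n = m2n1 + m2n2 + delta * delta * nnz1 * nnz2 / totalNnz
    (mergedMean, mergedM2n, totalNnz)
  }

  def main(args: Array[String]): Unit = {
    // Chunk A holds values (1, 2, 3); chunk B holds (10, 20) for one coordinate.
    val merged = mergeCoordinate(2.0, 2.0, 3.0, 15.0, 50.0, 2.0)
    println(merged) // roughly (7.2, 254.8, 5.0), the stats over (1, 2, 3, 10, 20)
  }
}

Merging summaries this way keeps the fold associative, so partition-level statistics can be combined in any order without a second pass over the data.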