@@ -21,7 +21,7 @@ import breeze.linalg.{Vector => BV, DenseVector => BDV}
2121import org .apache .spark .mllib .linalg .{Vector , Vectors }
2222import org .apache .spark .mllib .util .MLUtils ._
2323import org .apache .spark .rdd .RDD
24- import breeze .numerics ._
24+ import breeze .linalg ._
2525
2626/**
2727 * Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector ]] through an
@@ -163,23 +163,34 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
163163 }
164164 }
165165
/**
 * Computes column-wise summary statistics over the vectors of this RDD in a single
 * `aggregate` pass.
 *
 * Mean and variance use an incremental update per row (`seqOp`) and a pairwise merge of
 * partial results (`combOp`). Entries that a vector does not store explicitly are treated
 * as zeros for every statistic.
 *
 * For an RDD without rows the mean and non-zero counts are all zeros, the variance is NaN
 * (0 / 0), and max / min keep their initial sentinels (`Double.MinValue` /
 * `Double.MaxValue`).
 *
 * @param size number of columns; used to size every accumulator vector
 * @return a tuple of (mean, population variance, row count, non-zero count per column,
 *         max per column, min per column)
 */
def parallelMeanAndVar(size: Int): (Vector, Vector, Double, Vector, Vector, Vector) = {
  val statistics = self.map(_.toBreeze).aggregate((
    BV.zeros[Double](size),            // running mean
    BV.zeros[Double](size),            // running sum of squared deviations (M2n)
    0.0,                               // number of rows seen
    BV.zeros[Double](size),            // number of non-zero values per column
    BV.fill(size){Double.MinValue},    // running max per column
    BV.fill(size){Double.MaxValue}))(  // running min per column
    seqOp = (c, v) => (c, v) match {
      case ((prevMean, prevM2n, cnt, nnz, maxVec, minVec), currData) =>
        val currMean = ((prevMean :* cnt) + currData) :/ (cnt + 1.0)
        // The accumulator vectors are updated in place instead of allocating a temporary
        // sparse vector per row. Only genuinely non-zero values count towards nnz: an
        // active slot may still hold an explicit 0.0 (always the case for dense vectors).
        currData.activeIterator.foreach { case (id, value) =>
          if (value != 0.0) nnz(id) = nnz(id) + 1.0
          if (maxVec(id) < value) maxVec(id) = value
          if (minVec(id) > value) minVec(id) = value
        }
        (currMean, prevM2n + ((currData - prevMean) :* (currData - currMean)), cnt + 1.0,
          nnz, maxVec, minVec)
    },
    combOp = (lhs, rhs) =>
      // A side without rows contributes nothing. Merging it arithmetically would compute
      // 0 / 0 and leave NaNs that survive every later merge, so return the other side.
      if (lhs._3 == 0.0) {
        rhs
      } else if (rhs._3 == 0.0) {
        lhs
      } else {
        (lhs, rhs) match {
          case ((lhsMean, lhsM2n, lhsCnt, lhsNNZ, lhsMax, lhsMin),
                (rhsMean, rhsM2n, rhsCnt, rhsNNZ, rhsMax, rhsMin)) =>
            val totalCnt = lhsCnt + rhsCnt
            // Explicit parentheses: the weighted sum is divided as a whole.
            val totalMean = ((lhsMean :* lhsCnt) + (rhsMean :* rhsCnt)) :/ totalCnt
            val deltaMean = rhsMean - lhsMean
            val totalM2n =
              lhsM2n + rhsM2n + (((deltaMean :* deltaMean) :* (lhsCnt * rhsCnt)) :/ totalCnt)
            rhsMax.activeIterator.foreach { case (id, value) =>
              if (lhsMax(id) < value) lhsMax(id) = value
            }
            rhsMin.activeIterator.foreach { case (id, value) =>
              if (lhsMin(id) > value) lhsMin(id) = value
            }
            (totalMean, totalM2n, totalCnt, lhsNNZ + rhsNNZ, lhsMax, lhsMin)
        }
      }
  )

  val (mean, m2n, cnt, nnz, maxVec, minVec) = statistics

  // Max / min were only updated from stored entries. A column with fewer non-zero values
  // than rows contains at least one zero (stored or implicit), so fold 0.0 into its range.
  (0 until size).foreach { i =>
    if (nnz(i) < cnt) {
      if (maxVec(i) < 0.0) maxVec(i) = 0.0
      if (minVec(i) > 0.0) minVec(i) = 0.0
    }
  }

  (Vectors.fromBreeze(mean),
    Vectors.fromBreeze(m2n :/ cnt),
    cnt,
    Vectors.fromBreeze(nnz),
    Vectors.fromBreeze(maxVec),
    Vectors.fromBreeze(minVec))
}
185196}
0 commit comments