Skip to content

Commit 1338ea1

Browse files
committed
all-in-one version test passed
1 parent cc65810 commit 1338ea1

File tree

2 files changed

+24
-9
lines changed

2 files changed

+24
-9
lines changed

mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import breeze.linalg.{Vector => BV, DenseVector => BDV}
2121
import org.apache.spark.mllib.linalg.{Vector, Vectors}
2222
import org.apache.spark.mllib.util.MLUtils._
2323
import org.apache.spark.rdd.RDD
24-
import breeze.numerics._
24+
import breeze.linalg._
2525

2626
/**
2727
* Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector]] through an
@@ -163,23 +163,34 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
163163
}
164164
}
165165

166-
def parallelMeanAndVar(size: Int): (Vector, Vector) = {
167-
val statistics = self.map(_.toBreeze).aggregate((BV.zeros[Double](size), BV.zeros[Double](size), 0.0))(
166+
def parallelMeanAndVar(size: Int): (Vector, Vector, Double, Vector, Vector, Vector) = {
167+
val statistics = self.map(_.toBreeze).aggregate((BV.zeros[Double](size), BV.zeros[Double](size), 0.0, BV.zeros[Double](size), BV.fill(size){Double.MinValue}, BV.fill(size){Double.MaxValue}))(
168168
seqOp = (c, v) => (c, v) match {
169-
case ((prevMean, prevM2n, cnt), currData) =>
169+
case ((prevMean, prevM2n, cnt, nnz, maxVec, minVec), currData) =>
170170
val currMean = ((prevMean :* cnt) + currData) :/ (cnt + 1.0)
171-
(currMean, prevM2n + ((currData - prevMean) :* (currData - currMean)), cnt + 1.0)
171+
val nonZeroCnt = Vectors.sparse(size, currData.activeKeysIterator.toSeq.map(x => (x, 1.0))).toBreeze
172+
currData.activeIterator.foreach { case (id, value) =>
173+
if (maxVec(id) < value) maxVec(id) = value
174+
if (minVec(id) > value) minVec(id) = value
175+
}
176+
(currMean, prevM2n + ((currData - prevMean) :* (currData - currMean)), cnt + 1.0, nnz + nonZeroCnt, maxVec, minVec)
172177
},
173178
combOp = (lhs, rhs) => (lhs, rhs) match {
174-
case ((lhsMean, lhsM2n, lhsCnt), (rhsMean, rhsM2n, rhsCnt)) =>
179+
case ((lhsMean, lhsM2n, lhsCnt, lhsNNZ, lhsMax, lhsMin), (rhsMean, rhsM2n, rhsCnt, rhsNNZ, rhsMax, rhsMin)) =>
175180
val totalCnt = lhsCnt + rhsCnt
176181
val totalMean = (lhsMean :* lhsCnt) + (rhsMean :* rhsCnt) :/ totalCnt
177182
val deltaMean = rhsMean - lhsMean
178183
val totalM2n = lhsM2n + rhsM2n + (((deltaMean :* deltaMean) :* (lhsCnt * rhsCnt)) :/ totalCnt)
179-
(totalMean, totalM2n, totalCnt)
184+
rhsMax.activeIterator.foreach { case (id, value) =>
185+
if (lhsMax(id) < value) lhsMax(id) = value
186+
}
187+
rhsMin.activeIterator.foreach { case (id, value) =>
188+
if (lhsMin(id) > value) lhsMin(id) = value
189+
}
190+
(totalMean, totalM2n, totalCnt, lhsNNZ + rhsNNZ, lhsMax, lhsMin)
180191
}
181192
)
182193

183-
(Vectors.fromBreeze(statistics._1), Vectors.fromBreeze(statistics._2 :/ statistics._3))
194+
(Vectors.fromBreeze(statistics._1), Vectors.fromBreeze(statistics._2 :/ statistics._3), statistics._3, Vectors.fromBreeze(statistics._4), Vectors.fromBreeze(statistics._5), Vectors.fromBreeze(statistics._6))
184195
}
185196
}

mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,13 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
132132

133133
test("meanAndVar") {
134134
val data = sc.parallelize(localData, 2)
135-
val (mean, sd) = data.parallelMeanAndVar(3)
135+
val (mean, sd, cnt, nnz, max, min) = data.parallelMeanAndVar(3)
136136
assert(equivVector(mean, Vectors.dense(colMeans)), "Column means do not match.")
137137
assert(equivVector(sd, Vectors.dense(colVar)), "Column SD do not match.")
138+
assert(cnt === 3, "Column cnt do not match.")
139+
assert(equivVector(nnz, Vectors.dense(3.0, 3.0, 3.0)), "Column nnz do not match.")
140+
assert(equivVector(max, Vectors.dense(7.0, 8.0, 9.0)), "Column max do not match.")
141+
assert(equivVector(min, Vectors.dense(1.0, 2.0, 3.0)), "Column min do not match.")
138142
}
139143
}
140144

0 commit comments

Comments
 (0)