Skip to content

Commit 48ee053

Browse files
committed
fix minor error
1 parent e624f93 commit 48ee053

File tree

2 files changed

+34
-11
lines changed

2 files changed

+34
-11
lines changed

mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,35 @@ import org.apache.spark.rdd.RDD
2727
* count.
2828
*/
2929
trait VectorRDDStatisticalSummary {
30+
31+
/**
32+
* Computes the mean of columns in RDD[Vector].
33+
*/
3034
def mean: Vector
35+
36+
/**
37+
* Computes the sample variance of columns in RDD[Vector].
38+
*/
3139
def variance: Vector
40+
41+
/**
42+
* Computes the number of vectors in RDD[Vector].
43+
*/
3244
def count: Long
45+
46+
/**
47+
* Computes the number of non-zero elements in each column of RDD[Vector].
48+
*/
3349
def numNonZeros: Vector
50+
51+
/**
52+
* Computes the maximum of each column in RDD[Vector].
53+
*/
3454
def max: Vector
55+
56+
/**
57+
* Computes the minimum of each column in RDD[Vector].
58+
*/
3559
def min: Vector
3660
}
3761

@@ -53,7 +77,6 @@ private class VectorRDDStatisticsAggregator(
5377
val currMin: BDV[Double])
5478
extends VectorRDDStatisticalSummary with Serializable {
5579

56-
// lazy val is used for computing only once time. Same below.
5780
override def mean = {
5881
val realMean = BDV.zeros[Double](currMean.length)
5982
var i = 0
@@ -71,7 +94,7 @@ private class VectorRDDStatisticsAggregator(
7194
while (i < currM2n.size) {
7295
realVariance(i) =
7396
currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
74-
realVariance(i) /= totalCnt
97+
realVariance(i) /= (totalCnt - 1.0)
7598
i += 1
7699
}
77100
Vectors.fromBreeze(realVariance)

mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
3434

3535
val localData = Array(
3636
Vectors.dense(1.0, 2.0, 3.0),
37-
Vectors.dense(4.0, 5.0, 6.0),
38-
Vectors.dense(7.0, 8.0, 9.0)
37+
Vectors.dense(4.0, 0.0, 6.0),
38+
Vectors.dense(0.0, 8.0, 9.0)
3939
)
4040

4141
val sparseData = ArrayBuffer(Vectors.sparse(3, Seq((0, 1.0))))
@@ -47,21 +47,21 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
4747
val data = sc.parallelize(localData, 2)
4848
val summary = data.computeSummaryStatistics()
4949

50-
assert(equivVector(summary.mean, Vectors.dense(4.0, 5.0, 6.0)),
50+
assert(equivVector(summary.mean, Vectors.dense(5.0 / 3.0, 10.0 / 3.0, 6.0)),
5151
"Dense column mean do not match.")
5252

53-
assert(equivVector(summary.variance, Vectors.dense(6.0, 6.0, 6.0)),
53+
assert(equivVector(summary.variance, Vectors.dense(4.333333333333334, 17.333333333333336, 9.0)),
5454
"Dense column variance do not match.")
5555

5656
assert(summary.count === 3, "Dense column cnt do not match.")
5757

58-
assert(equivVector(summary.numNonZeros, Vectors.dense(3.0, 3.0, 3.0)),
58+
assert(equivVector(summary.numNonZeros, Vectors.dense(2.0, 2.0, 3.0)),
5959
"Dense column nnz do not match.")
6060

61-
assert(equivVector(summary.max, Vectors.dense(7.0, 8.0, 9.0)),
61+
assert(equivVector(summary.max, Vectors.dense(4.0, 8.0, 9.0)),
6262
"Dense column max do not match.")
6363

64-
assert(equivVector(summary.min, Vectors.dense(1.0, 2.0, 3.0)),
64+
assert(equivVector(summary.min, Vectors.dense(0.0, 0.0, 3.0)),
6565
"Dense column min do not match.")
6666
}
6767

@@ -72,7 +72,7 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
7272
assert(equivVector(summary.mean, Vectors.dense(0.06, 0.05, 0.0)),
7373
"Sparse column mean do not match.")
7474

75-
assert(equivVector(summary.variance, Vectors.dense(0.2564, 0.2475, 0.0)),
75+
assert(equivVector(summary.variance, Vectors.dense(0.258989898989899, 0.25, 0.0)),
7676
"Sparse column variance do not match.")
7777

7878
assert(summary.count === 100, "Sparse column cnt do not match.")
@@ -90,6 +90,6 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
9090

9191
object VectorRDDFunctionsSuite {
9292
def equivVector(lhs: Vector, rhs: Vector): Boolean = {
93-
(lhs.toBreeze - rhs.toBreeze).norm(2) < 1e-9
93+
(lhs.toBreeze - rhs.toBreeze).norm(2) < 1e-5
9494
}
9595
}

0 commit comments

Comments
 (0)