Skip to content

Commit 9a75ebd

Browse files
committed
add case class to wrap return values
1 parent d816ac7 commit 9a75ebd

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@ import breeze.linalg.{Vector => BV}
2121
import org.apache.spark.mllib.linalg.{Vector, Vectors}
2222
import org.apache.spark.rdd.RDD
2323

24+
/**
 * Bundles the statistics returned by `VectorRDDFunctions.summarizeStatistics` so that callers can
 * access each result by name instead of by tuple position.
 *
 * The field order deliberately mirrors the positional order already used when this class is
 * constructed in `summarizeStatistics` and when it is destructured in `VectorRDDFunctionsSuite`
 * (`mean, variance, cnt, nnz, max, min`), so that named access and positional access agree.
 * NOTE(review): this assumes the fourth aggregated component is the non-zero count, as the
 * test's variable names indicate -- confirm against the aggregate's seed/combine functions.
 *
 * @param mean column-wise mean
 * @param variance column-wise variance
 * @param count number of vectors that were summarized
 * @param nonZeroCnt column-wise count of non-zero entries
 * @param max column-wise maximum
 * @param min column-wise minimum
 */
case class VectorRDDStatisticalSummary(
    mean: Vector,
    variance: Vector,
    count: Long,
    nonZeroCnt: Vector,
    max: Vector,
    min: Vector) extends Serializable
31+
2432
/**
2533
* Extra functions available on RDDs of [[org.apache.spark.mllib.linalg.Vector Vector]] through an
2634
* implicit conversion. Import `org.apache.spark.MLContext._` at the top of your program to use
@@ -40,7 +48,7 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
4048
* }}},
4149
* with the size of Vector as input parameter.
4250
*/
43-
def statistics(size: Int): (Vector, Vector, Double, Vector, Vector, Vector) = {
51+
def summarizeStatistics(size: Int): VectorRDDStatisticalSummary = {
4452
val results = self.map(_.toBreeze).aggregate((
4553
BV.zeros[Double](size),
4654
BV.zeros[Double](size),
@@ -83,9 +91,10 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
8391
}
8492
)
8593

86-
(Vectors.fromBreeze(results._1),
94+
VectorRDDStatisticalSummary(
95+
Vectors.fromBreeze(results._1),
8796
Vectors.fromBreeze(results._2 :/ results._3),
88-
results._3,
97+
results._3.toLong,
8998
Vectors.fromBreeze(results._4),
9099
Vectors.fromBreeze(results._5),
91100
Vectors.fromBreeze(results._6))

mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
3333

3434
test("full-statistics") {
3535
val data = sc.parallelize(localData, 2)
36-
val (mean, variance, cnt, nnz, max, min) = data.statistics(3)
36+
val VectorRDDStatisticalSummary(mean, variance, cnt, nnz, max, min) = data.summarizeStatistics(3)
3737
assert(equivVector(mean, Vectors.dense(4.0, 5.0, 6.0)), "Column mean do not match.")
3838
assert(equivVector(variance, Vectors.dense(6.0, 6.0, 6.0)), "Column variance do not match.")
3939
assert(cnt === 3, "Column cnt do not match.")

0 commit comments

Comments
 (0)