1717
1818package org .apache .spark .mllib .rdd
1919
20- import org .apache .spark .mllib .linalg .Vector
2120import org .scalatest .FunSuite
22-
23- import org .apache .spark .mllib .linalg .Vectors
24- import org .apache .spark .mllib .util .MLUtils ._
25- import VectorRDDFunctionsSuite ._
21+ import org .apache .spark .mllib .linalg .{Vector , Vectors }
2622import org .apache .spark .mllib .util .LocalSparkContext
23+ import org .apache .spark .mllib .util .MLUtils ._
2724
2825class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
26+ import VectorRDDFunctionsSuite ._
2927
3028 val localData = Array (
31- Vectors .dense(1.0 , 2.0 , 3.0 ),
32- Vectors .dense(4.0 , 5.0 , 6.0 ),
33- Vectors .dense(7.0 , 8.0 , 9.0 )
34- )
29+ Vectors .dense(1.0 , 2.0 , 3.0 ),
30+ Vectors .dense(4.0 , 5.0 , 6.0 ),
31+ Vectors .dense(7.0 , 8.0 , 9.0 )
32+ )
3533
3634 val rowMeans = Array (2.0 , 5.0 , 8.0 )
3735 val rowNorm2 = Array (math.sqrt(14.0 ), math.sqrt(77.0 ), math.sqrt(194.0 ))
@@ -44,6 +42,23 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
4442 val maxVec = Array (7.0 , 8.0 , 9.0 )
4543 val minVec = Array (1.0 , 2.0 , 3.0 )
4644
45+ val shrinkingData = Array (
46+ Vectors .dense(1.0 , 2.0 , 0.0 ),
47+ Vectors .dense(0.0 , 0.0 , 0.0 ),
48+ Vectors .dense(7.0 , 8.0 , 0.0 )
49+ )
50+
51+ val rowShrinkData = Array (
52+ Vectors .dense(1.0 , 2.0 , 0.0 ),
53+ Vectors .dense(7.0 , 8.0 , 0.0 )
54+ )
55+
56+ val colShrinkData = Array (
57+ Vectors .dense(1.0 , 2.0 ),
58+ Vectors .dense(0.0 , 0.0 ),
59+ Vectors .dense(7.0 , 8.0 )
60+ )
61+
4762 test(" rowMeans" ) {
4863 val data = sc.parallelize(localData, 2 )
4964 assert(equivVector(Vectors .dense(data.rowMeans().collect()), Vectors .dense(rowMeans)), " Row means do not match." )
@@ -91,6 +106,22 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
91106 " Optional minimum does not match."
92107 )
93108 }
109+
110+ test(" rowShrink" ) {
111+ val data = sc.parallelize(shrinkingData, 2 )
112+ val res = data.rowShrink().collect()
113+ rowShrinkData.zip(res).foreach { case (lhs, rhs) =>
114+ assert(equivVector(lhs, rhs), " Row shrink error." )
115+ }
116+ }
117+
118+ test(" columnShrink" ) {
119+ val data = sc.parallelize(shrinkingData, 2 )
120+ val res = data.colShrink().collect()
121+ colShrinkData.zip(res).foreach { case (lhs, rhs) =>
122+ assert(equivVector(lhs, rhs), " Column shrink error." )
123+ }
124+ }
94125}
95126
96127object VectorRDDFunctionsSuite {
0 commit comments