Skip to content

Commit 54b19ab

Browse files
committed
add new API to shrink RDD[Vector]
1 parent 8c6c0e1 commit 54b19ab

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,18 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
9393

9494
/** Delegates to `maxMinOption`, passing the logical negation of `cmp` as the comparator. */
def minOption(cmp: (Vector, Vector) => Boolean) = maxMinOption((left, right) => !cmp(left, right))
9595

96-
def rowShrink(): RDD[Vector] = {
96+
/**
 * Keeps only the vectors that contain at least one non-zero entry.
 *
 * Each element is tested individually rather than testing the row sum: entries of
 * opposite sign can cancel to a sum of zero (e.g. `[1.0, -1.0]`), and such a row
 * must not be treated as empty.
 *
 * @return an RDD containing the vectors of `self` that are not entirely zero
 */
def rowShrink(): RDD[Vector] = self.filter(_.toArray.exists(_ != 0.0))
97+
98+
/**
 * Rebuilds every vector as a dense vector that keeps only the positions whose
 * corresponding entry in `self.colMeans()` is not exactly 0.0.
 *
 * NOTE(review): `colMeans` is defined outside this view; presumably it yields one
 * mean per column — confirm. A column whose values cancel to a mean of exactly 0.0
 * (e.g. 1.0 and -1.0) is removed even though it holds non-zero entries.
 */
def colShrink(): RDD[Vector] = {
99+
val means = self.colMeans()
100+
// Pair each entry with its column's mean, drop pairs whose mean is 0.0, keep the entry.
self.map( v => Vectors.dense(v.toArray.zip(means.toArray).filter{ case (x, m) => m != 0.0 }.map(_._1)))
101+
}
102+
103+
/**
 * Not implemented: the body is `???`, so every call throws `scala.NotImplementedError`.
 *
 * NOTE(review): presumably meant to return the column-shrunk RDD together with a
 * per-column keep/drop mask — TODO confirm the intended contract before implementing.
 */
def colShrinkWithFilter(): (RDD[Vector], RDD[Boolean]) = {
97104
???
98105
}
99106

100-
def colShrink(): RDD[Vector] = {
107+
/**
 * Not implemented: the body is `???`, so every call throws `scala.NotImplementedError`.
 *
 * NOTE(review): presumably meant to return the row-shrunk RDD together with a
 * per-row keep/drop mask — TODO confirm the intended contract before implementing.
 */
def rowShrinkWithFilter(): (RDD[Vector], RDD[Boolean]) = {
101108
???
102109
}
103110
}

0 commit comments

Comments
 (0)