Skip to content

Commit e09d5d2

Browse files
committed
add scala docs and refine shrink method
1 parent 8ef3377 commit e09d5d2

File tree

1 file changed

+59
-4
lines changed

1 file changed

+59
-4
lines changed

mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
*/
1717
package org.apache.spark.mllib.rdd
1818

19-
import breeze.linalg.{Vector => BV, *}
19+
import breeze.linalg.{Vector => BV, DenseVector => BDV}
2020

2121
import org.apache.spark.mllib.linalg.{Vector, Vectors}
2222
import org.apache.spark.mllib.util.MLUtils._
@@ -28,23 +28,38 @@ import org.apache.spark.rdd.RDD
2828
*/
2929
class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
3030

31+
/**
32+
* Compute the mean of each `Vector` in the RDD.
33+
*/
3134
def rowMeans(): RDD[Double] = {
3235
self.map(x => x.toArray.sum / x.size)
3336
}
3437

38+
/**
39+
* Compute the norm-2 of each `Vector` in the RDD.
40+
*/
3541
def rowNorm2(): RDD[Double] = {
3642
self.map(x => math.sqrt(x.toArray.map(x => x*x).sum))
3743
}
3844

45+
/**
46+
* Compute the standard deviation of each `Vector` in the RDD.
47+
*/
3948
def rowSDs(): RDD[Double] = {
4049
val means = self.rowMeans()
4150
self.zip(means)
4251
.map{ case(x, m) => x.toBreeze - m }
4352
.map{ x => math.sqrt(x.toArray.map(x => x*x).sum / x.size) }
4453
}
4554

55+
/**
56+
* Compute the mean of each column in the RDD.
57+
*/
4658
def colMeans(): Vector = colMeans(self.take(1).head.size)
4759

60+
/**
61+
* Compute the mean of each column in the RDD with `size` as the dimension of each `Vector`.
62+
*/
4863
def colMeans(size: Int): Vector = {
4964
Vectors.fromBreeze(self.map(_.toBreeze).aggregate((BV.zeros[Double](size), 0.0))(
5065
seqOp = (c, v) => (c, v) match {
@@ -58,15 +73,27 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
5873
)._1)
5974
}
6075

76+
/**
77+
* Compute the norm-2 of each column in the RDD.
78+
*/
6179
def colNorm2(): Vector = colNorm2(self.take(1).head.size)
6280

81+
/**
82+
* Compute the norm-2 of each column in the RDD with `size` as the dimension of each `Vector`.
83+
*/
6384
def colNorm2(size: Int): Vector = Vectors.fromBreeze(self.map(_.toBreeze).aggregate(BV.zeros[Double](size))(
6485
seqOp = (c, v) => c + (v :* v),
6586
combOp = (lhs, rhs) => lhs + rhs
6687
).map(math.sqrt))
6788

89+
/**
90+
* Compute the standard deviation of each column in the RDD.
91+
*/
6892
def colSDs(): Vector = colSDs(self.take(1).head.size)
6993

94+
/**
95+
* Compute the standard deviation of each column in the RDD with `size` as the dimension of each `Vector`.
96+
*/
7097
def colSDs(size: Int): Vector = {
7198
val means = self.colMeans()
7299
Vectors.fromBreeze(self.map(x => x.toBreeze - means.toBreeze).aggregate((BV.zeros[Double](size), 0.0))(
@@ -81,21 +108,49 @@ class VectorRDDFunctions(self: RDD[Vector]) extends Serializable {
81108
)._1.map(math.sqrt))
82109
}
83110

111+
/**
112+
* Find the optional max or min vector in the RDD.
113+
*/
84114
private def maxMinOption(cmp: (Vector, Vector) => Boolean): Option[Vector] = {
85115
def cmpMaxMin(x1: Vector, x2: Vector) = if (cmp(x1, x2)) x1 else x2
86116
self.mapPartitions { iterator =>
87117
Seq(iterator.reduceOption(cmpMaxMin)).iterator
88118
}.collect { case Some(x) => x }.collect().reduceOption(cmpMaxMin)
89119
}
90120

121+
/**
122+
* Find the optional max vector in the RDD, `None` will be returned if there is no elements at all.
123+
*/
91124
def maxOption(cmp: (Vector, Vector) => Boolean) = maxMinOption(cmp)
92125

126+
/**
127+
* Find the optional min vector in the RDD, `None` will be returned if there is no elements at all.
128+
*/
93129
def minOption(cmp: (Vector, Vector) => Boolean) = maxMinOption(!cmp(_, _))
94130

95-
def rowShrink(): RDD[Vector] = self.filter(x => x.toArray.sum != 0)
131+
/**
132+
* Filter the vectors whose standard deviation is not zero.
133+
*/
134+
def rowShrink(): RDD[Vector] = self.zip(self.rowSDs()).filter(_._2 != 0.0).map(_._1)
96135

136+
/**
137+
* Filter each column of the RDD whose standard deviation is not zero.
138+
*/
97139
def colShrink(): RDD[Vector] = {
98-
val means = self.colMeans()
99-
self.map( v => Vectors.dense(v.toArray.zip(means.toArray).filter{ case (x, m) => m != 0.0 }.map(_._1)))
140+
val sds = self.colSDs()
141+
self.take(1).head.toBreeze.isInstanceOf[BDV[Double]] match {
142+
case true =>
143+
self.map{ v =>
144+
Vectors.dense(v.toArray.zip(sds.toArray).filter{case (x, m) => m != 0.0}.map(_._1))
145+
}
146+
case false =>
147+
self.map { v =>
148+
val filtered = v.toArray.zip(sds.toArray).filter{case (x, m) => m != 0.0}.map(_._1)
149+
val denseVector = Vectors.dense(filtered).toBreeze
150+
val size = denseVector.size
151+
val iterElement = denseVector.activeIterator.toSeq
152+
Vectors.sparse(size, iterElement)
153+
}
154+
}
100155
}
101156
}

0 commit comments

Comments
 (0)