Skip to content

Commit 6bd0a10

Browse files
committed
merge repeated features
1 parent 419f8a2 commit 6bd0a10

File tree

2 files changed

+36
-12
lines changed

2 files changed

+36
-12
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialMapper.scala

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -74,17 +74,40 @@ object PolynomialMapper {
7474
}
7575

7676
/**
77-
* Multiply two polynomials.
77+
* Multiply two polynomials, the first is the original vector, i.e. the expanded vector with
78+
* degree 1, while the second is the expanded vector with degree `currDegree - 1`. A new expanded
79+
* vector with degree `currDegree` will be generated after the function call.
80+
*
81+
* @param lhs original vector with degree 1
82+
* @param rhs expanded vector with degree `currDegree - 1`
83+
* @param nDim the dimension of original vector
84+
* @param currDegree the polynomial degree that needs to be achieved
7885
*/
79-
private def expandVector(lhs: Vector, rhs: Vector): Vector = {
86+
private def expandVector(lhs: Vector, rhs: Vector, nDim: Int, currDegree: Int): Vector = {
8087
(lhs, rhs) match {
8188
case (l: DenseVector, r: DenseVector) =>
82-
Vectors.dense(l.toArray.flatMap(lx => r.toArray.map(rx => lx * rx)))
89+
var rightVectorView = rhs.toArray
90+
val allExpansions = l.toArray.zipWithIndex.flatMap { case (lVal, lIdx) =>
91+
val currExpansions = rightVectorView.map(rVal => lVal * rVal)
92+
val numToRemove = numMonomials(currDegree - 1, nDim - lIdx)
93+
rightVectorView = rightVectorView.drop(numToRemove)
94+
currExpansions
95+
}
96+
Vectors.dense(allExpansions)
97+
8398
case (SparseVector(lLen, lIdx, lVal), SparseVector(rLen, rIdx, rVal)) =>
84-
val len = lLen * rLen
85-
val idx = lIdx.flatMap(li => rIdx.map(ri => li * lLen + ri))
86-
val value = lVal.flatMap(lv => rVal.map(rv => lv * rv))
87-
Vectors.sparse(len, idx, value)
99+
val len = numMonomials(currDegree, nDim)
100+
var numToRemoveCum = 0
101+
val allExpansions = lVal.zip(lIdx).flatMap { case (lv, li) =>
102+
val currExpansions = rVal.zip(rIdx).map { case (rv, ri) =>
103+
val realIdx = li * nDim + ri
104+
(if(realIdx > numToRemoveCum) lv * rv else 0.0, realIdx - numToRemoveCum)
105+
}
106+
numToRemoveCum += numMonomials(currDegree - 1, nDim - li)
107+
currExpansions
108+
}
109+
Vectors.sparse(len, allExpansions.map(_._2), allExpansions.map(_._1))
110+
88111
case _ => throw new Exception("vector types are not match.")
89112
}
90113
}
@@ -94,14 +117,15 @@ object PolynomialMapper {
94117
* degree 1 to degree `degree`.
95118
*/
96119
private def transform(degree: Int)(feature: Vector): Vector = {
120+
val nDim = feature.size
97121
feature match {
98122
case f: DenseVector =>
99-
(2 to degree).foldLeft(Array(feature.copy)) { (vectors, _) =>
100-
vectors ++ Array(expandVector(feature, vectors.last))
123+
(2 to degree).foldLeft(Array(feature.copy)) { (vectors, currDegree) =>
124+
vectors ++ Array(expandVector(feature, vectors.last, nDim, currDegree))
101125
}.reduce((lhs, rhs) => Vectors.dense(lhs.toArray ++ rhs.toArray))
102126
case f: SparseVector =>
103-
(2 to degree).foldLeft(Array(feature.copy)) { (vectors, _) =>
104-
vectors ++ Array(expandVector(feature, vectors.last))
127+
(2 to degree).foldLeft(Array(feature.copy)) { (vectors, currDegree) =>
128+
vectors ++ Array(expandVector(feature, vectors.last, nDim, currDegree))
105129
}.reduce { (lhs, rhs) =>
106130
(lhs, rhs) match {
107131
case (SparseVector(lLen, lIdx, lVal), SparseVector(rLen, rIdx, rVal)) =>

mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialMapperSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class PolynomialMapperSuite extends FunSuite with MLlibTestSparkContext {
7979
}
8080

8181
test("fake") {
82-
val result = collectResult(polynomialMapper.setDegree(3).transform(dataFrame))
82+
val result = collectResult(polynomialMapper.setDegree(2).transform(dataFrame))
8383
for(r <- result) {
8484
println(r)
8585
}

0 commit comments

Comments (0)