Skip to content

Commit 5bffd3d

Browse files
author
DB Tsai
committed
first commit
1 parent cb0e9b0 commit 5bffd3d

File tree

1 file changed

+36
-19
lines changed

1 file changed

+36
-19
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,9 @@
1717

1818
package org.apache.spark.mllib.feature
1919

20-
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
21-
2220
import org.apache.spark.Logging
2321
import org.apache.spark.annotation.Experimental
24-
import org.apache.spark.mllib.linalg.{Vector, Vectors}
22+
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
2523
import org.apache.spark.mllib.rdd.RDDFunctions._
2624
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
2725
import org.apache.spark.rdd.RDD
@@ -77,8 +75,8 @@ class StandardScalerModel private[mllib] (
7775

7876
require(mean.size == variance.size)
7977

80-
private lazy val factor: BDV[Double] = {
81-
val f = BDV.zeros[Double](variance.size)
78+
private lazy val factor: Array[Double] = {
79+
val f = Array.ofDim[Double](variance.size)
8280
var i = 0
8381
while (i < f.size) {
8482
f(i) = if (variance(i) != 0.0) 1.0 / math.sqrt(variance(i)) else 0.0
@@ -87,6 +85,8 @@ class StandardScalerModel private[mllib] (
8785
f
8886
}
8987

88+
private lazy val shift: Array[Double] = mean.toArray
89+
9090
/**
9191
* Applies standardization transformation on a vector.
9292
*
@@ -96,31 +96,48 @@ class StandardScalerModel private[mllib] (
9696
*/
9797
override def transform(vector: Vector): Vector = {
9898
require(mean.size == vector.size)
99+
val localFactor = factor
99100
if (withMean) {
100-
vector.toBreeze match {
101-
case dv: BDV[Double] =>
102-
val output = vector.toBreeze.copy
101+
val localShift = shift
102+
vector match {
103+
case dv: DenseVector =>
104+
val values = dv.values.clone()
103105
var i = 0
104-
while (i < output.length) {
105-
output(i) = (output(i) - mean(i)) * (if (withStd) factor(i) else 1.0)
106-
i += 1
106+
if(withStd) {
107+
while (i < values.length) {
108+
values(i) = (values(i) - localShift(i)) * localFactor(i)
109+
i += 1
110+
}
111+
} else {
112+
while (i < values.length) {
113+
values(i) -= localShift(i)
114+
i += 1
115+
}
107116
}
108-
Vectors.fromBreeze(output)
117+
Vectors.dense(values)
109118
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
110119
}
111120
} else if (withStd) {
112-
vector.toBreeze match {
113-
case dv: BDV[Double] => Vectors.fromBreeze(dv :* factor)
114-
case sv: BSV[Double] =>
121+
vector match {
122+
case dv: DenseVector =>
123+
val values = dv.values.clone()
124+
var i = 0
125+
while(i < values.length) {
126+
values(i) *= localFactor(i)
127+
i += 1
128+
}
129+
Vectors.dense(values)
130+
case sv: SparseVector =>
115131
// For sparse vector, the `index` array inside sparse vector object will not be changed,
116132
// so we can re-use it to save memory.
117-
val output = new BSV[Double](sv.index, sv.data.clone(), sv.length)
133+
val indices = sv.indices
134+
val values = sv.values.clone()
118135
var i = 0
119-
while (i < output.data.length) {
120-
output.data(i) *= factor(output.index(i))
136+
while (i < indices.length) {
137+
values(i) *= localFactor(indices(i))
121138
i += 1
122139
}
123-
Vectors.fromBreeze(output)
140+
Vectors.sparse(sv.size, indices, values)
124141
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
125142
}
126143
} else {

0 commit comments

Comments
 (0)