1717
1818package org .apache .spark .mllib .feature
1919
20- import breeze .linalg .{DenseVector => BDV , SparseVector => BSV }
21-
2220import org .apache .spark .Logging
2321import org .apache .spark .annotation .Experimental
24- import org .apache .spark .mllib .linalg .{Vector , Vectors }
22+ import org .apache .spark .mllib .linalg .{DenseVector , SparseVector , Vector , Vectors }
2523import org .apache .spark .mllib .rdd .RDDFunctions ._
2624import org .apache .spark .mllib .stat .MultivariateOnlineSummarizer
2725import org .apache .spark .rdd .RDD
@@ -77,8 +75,8 @@ class StandardScalerModel private[mllib] (
7775
7876 require(mean.size == variance.size)
7977
80- private lazy val factor : BDV [Double ] = {
81- val f = BDV .zeros [Double ](variance.size)
78+ private lazy val factor : Array [Double ] = {
79+ val f = Array .ofDim [Double ](variance.size)
8280 var i = 0
8381 while (i < f.size) {
8482 f(i) = if (variance(i) != 0.0 ) 1.0 / math.sqrt(variance(i)) else 0.0
@@ -87,6 +85,8 @@ class StandardScalerModel private[mllib] (
8785 f
8886 }
8987
88+ private lazy val shift : Array [Double ] = mean.toArray
89+
9090 /**
9191 * Applies standardization transformation on a vector.
9292 *
@@ -96,31 +96,48 @@ class StandardScalerModel private[mllib] (
9696 */
9797 override def transform (vector : Vector ): Vector = {
9898 require(mean.size == vector.size)
99+ val localFactor = factor
99100 if (withMean) {
100- vector.toBreeze match {
101- case dv : BDV [Double ] =>
102- val output = vector.toBreeze.copy
101+ val localShift = shift
102+ vector match {
103+ case dv : DenseVector =>
104+ val values = dv.values.clone()
103105 var i = 0
104- while (i < output.length) {
105- output(i) = (output(i) - mean(i)) * (if (withStd) factor(i) else 1.0 )
106- i += 1
106+ if (withStd) {
107+ while (i < values.length) {
108+ values(i) = (values(i) - localShift(i)) * localFactor(i)
109+ i += 1
110+ }
111+ } else {
112+ while (i < values.length) {
113+ values(i) -= localShift(i)
114+ i += 1
115+ }
107116 }
108- Vectors .fromBreeze(output )
117+ Vectors .dense(values )
109118 case v => throw new IllegalArgumentException (" Do not support vector type " + v.getClass)
110119 }
111120 } else if (withStd) {
112- vector.toBreeze match {
113- case dv : BDV [Double ] => Vectors .fromBreeze(dv :* factor)
114- case sv : BSV [Double ] =>
121+ vector match {
122+ case dv : DenseVector =>
123+ val values = dv.values.clone()
124+ var i = 0
125+ while (i < values.length) {
126+ values(i) *= localFactor(i)
127+ i += 1
128+ }
129+ Vectors .dense(values)
130+ case sv : SparseVector =>
115131 // For sparse vector, the `index` array inside sparse vector object will not be changed,
116132 // so we can re-use it to save memory.
117- val output = new BSV [Double ](sv.index, sv.data.clone(), sv.length)
133+ val indices = sv.indices
134+ val values = sv.values.clone()
118135 var i = 0
119- while (i < output.data .length) {
120- output.data (i) *= factor(output.index (i))
136+ while (i < indices .length) {
137+ values (i) *= localFactor(indices (i))
121138 i += 1
122139 }
123- Vectors .fromBreeze(output )
140+ Vectors .sparse(sv.size, indices, values )
124141 case v => throw new IllegalArgumentException (" Do not support vector type " + v.getClass)
125142 }
126143 } else {
0 commit comments