
Commit bf1a6aa

DB Tsai authored and mengxr committed
[SPARK-4581][MLlib] Refactorize StandardScaler to improve the transformation performance
The following optimizations are done to improve the StandardScaler model transformation performance.

1) Convert the Breeze dense vector to a primitive array to reduce overhead.
2) Since the mean can potentially be a sparse vector, explicitly convert it to a dense primitive array.
3) Keep local references to the `shift` and `factor` arrays so the JVM can locate the values with a single operation.
4) In the pattern-matching part, use MLlib's SparseVector/DenseVector instead of Breeze's vectors to keep the codebase cleaner.

Benchmark with the mnist8m dataset:

Before:
- DenseVector, withMean and withStd: 50.97 secs
- DenseVector, withMean and withoutStd: 42.11 secs
- DenseVector, withoutMean and withStd: 8.75 secs
- SparseVector, withoutMean and withStd: 5.437 secs

With this PR:
- DenseVector, withMean and withStd: 5.76 secs
- DenseVector, withMean and withoutStd: 5.28 secs
- DenseVector, withoutMean and withStd: 5.30 secs
- SparseVector, withoutMean and withStd: 1.27 secs

Note that without the local reference copies of the `factor` and `shift` arrays, the runtime is almost three times slower:
- DenseVector, withMean and withStd: 18.15 secs
- DenseVector, withMean and withoutStd: 18.05 secs
- DenseVector, withoutMean and withStd: 18.54 secs
- SparseVector, withoutMean and withStd: 2.01 secs

The following code,

```scala
while (i < size) {
  values(i) = (values(i) - shift(i)) * factor(i)
  i += 1
}
```

will generate the bytecode

```
  L13
   LINENUMBER 106 L13
  FRAME FULL [org/apache/spark/mllib/feature/StandardScalerModel org/apache/spark/mllib/linalg/Vector org/apache/spark/mllib/linalg/Vector org/apache/spark/mllib/linalg/DenseVector T [D I I] []
    ILOAD 7
    ILOAD 6
    IF_ICMPGE L14
   L15
    LINENUMBER 107 L15
    ALOAD 5
    ILOAD 7
    ALOAD 5
    ILOAD 7
    DALOAD
    ALOAD 0
    INVOKESPECIAL org/apache/spark/mllib/feature/StandardScalerModel.shift ()[D
    ILOAD 7
    DALOAD
    DSUB
    ALOAD 0
    INVOKESPECIAL org/apache/spark/mllib/feature/StandardScalerModel.factor ()[D
    ILOAD 7
    DALOAD
    DMUL
    DASTORE
   L16
    LINENUMBER 108 L16
    ILOAD 7
    ICONST_1
    IADD
    ISTORE 7
    GOTO L13
```

while with the local references to the `shift` and `factor` arrays, the bytecode will be

```
  L14
   LINENUMBER 107 L14
   ALOAD 0
   INVOKESPECIAL org/apache/spark/mllib/feature/StandardScalerModel.factor ()[D
   ASTORE 9
  L15
   LINENUMBER 108 L15
  FRAME FULL [org/apache/spark/mllib/feature/StandardScalerModel org/apache/spark/mllib/linalg/Vector [D org/apache/spark/mllib/linalg/Vector org/apache/spark/mllib/linalg/DenseVector T [D I I [D] []
    ILOAD 8
    ILOAD 7
    IF_ICMPGE L16
   L17
    LINENUMBER 109 L17
    ALOAD 6
    ILOAD 8
    ALOAD 6
    ILOAD 8
    DALOAD
    ALOAD 2
    ILOAD 8
    DALOAD
    DSUB
    ALOAD 9
    ILOAD 8
    DALOAD
    DMUL
    DASTORE
   L18
    LINENUMBER 110 L18
    ILOAD 8
    ICONST_1
    IADD
    ISTORE 8
    GOTO L15
```

You can see that with the local references, both arrays are on the stack, so the JVM can access the values without calling `INVOKESPECIAL`.

Author: DB Tsai <[email protected]>

Closes #3435 from dbtsai/standardscaler and squashes the following commits:

85885a9 [DB Tsai] revert to have lazy in shift array.
daf2b06 [DB Tsai] Address the feedback
cdb5cef [DB Tsai] small change
9c51eef [DB Tsai] style
fc795e4 [DB Tsai] update
5bffd3d [DB Tsai] first commit
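To make the bytecode discussion above concrete, here is a hypothetical, stripped-down sketch of the local-reference pattern. The `Shifter` class and its members are illustrative only and are not part of this commit:

```scala
// Hypothetical, minimal illustration of the pattern used in this commit:
// copying a private member array into a local `val` before a tight loop,
// so the loop body reads the array from a local slot instead of calling
// the generated accessor (INVOKESPECIAL) on every iteration.
class Shifter(n: Int) {
  private lazy val shift: Array[Double] = Array.fill(n)(1.0)

  // Reads `shift` through the member accessor on every iteration.
  def transformSlow(values: Array[Double]): Unit = {
    var i = 0
    while (i < values.length) {
      values(i) -= shift(i)
      i += 1
    }
  }

  // Copies the member into a local reference once, then loops over the local.
  def transformFast(values: Array[Double]): Unit = {
    val localShift = shift
    var i = 0
    while (i < values.length) {
      values(i) -= localShift(i)
      i += 1
    }
  }
}
```

Because `shift` is a `lazy val`, every access through the member goes through a generated accessor method; hoisting it into a local once per call keeps the loop body free of those calls, which is the effect shown in the two bytecode listings above.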
1 parent 69cd53e commit bf1a6aa


mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala

Lines changed: 50 additions & 20 deletions
```diff
@@ -17,11 +17,9 @@
 
 package org.apache.spark.mllib.feature
 
-import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
-
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.rdd.RDD
@@ -77,8 +75,8 @@ class StandardScalerModel private[mllib] (
 
   require(mean.size == variance.size)
 
-  private lazy val factor: BDV[Double] = {
-    val f = BDV.zeros[Double](variance.size)
+  private lazy val factor: Array[Double] = {
+    val f = Array.ofDim[Double](variance.size)
     var i = 0
     while (i < f.size) {
       f(i) = if (variance(i) != 0.0) 1.0 / math.sqrt(variance(i)) else 0.0
@@ -87,6 +85,11 @@ class StandardScalerModel private[mllib] (
     f
   }
 
+  // Since `shift` will be only used in `withMean` branch, we have it as
+  // `lazy val` so it will be evaluated in that branch. Note that we don't
+  // want to create this array multiple times in `transform` function.
+  private lazy val shift: Array[Double] = mean.toArray
+
   /**
    * Applies standardization transformation on a vector.
    *
@@ -97,30 +100,57 @@ class StandardScalerModel private[mllib] (
   override def transform(vector: Vector): Vector = {
     require(mean.size == vector.size)
     if (withMean) {
-      vector.toBreeze match {
-        case dv: BDV[Double] =>
-          val output = vector.toBreeze.copy
-          var i = 0
-          while (i < output.length) {
-            output(i) = (output(i) - mean(i)) * (if (withStd) factor(i) else 1.0)
-            i += 1
+      // By default, Scala generates Java methods for member variables. So every time when
+      // the member variables are accessed, `invokespecial` will be called which is expensive.
+      // This can be avoid by having a local reference of `shift`.
+      val localShift = shift
+      vector match {
+        case dv: DenseVector =>
+          val values = dv.values.clone()
+          val size = values.size
+          if (withStd) {
+            // Having a local reference of `factor` to avoid overhead as the comment before.
+            val localFactor = factor
+            var i = 0
+            while (i < size) {
+              values(i) = (values(i) - localShift(i)) * localFactor(i)
+              i += 1
+            }
+          } else {
+            var i = 0
+            while (i < size) {
+              values(i) -= localShift(i)
+              i += 1
+            }
           }
-          Vectors.fromBreeze(output)
+          Vectors.dense(values)
         case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
       }
     } else if (withStd) {
-      vector.toBreeze match {
-        case dv: BDV[Double] => Vectors.fromBreeze(dv :* factor)
-        case sv: BSV[Double] =>
+      // Having a local reference of `factor` to avoid overhead as the comment before.
+      val localFactor = factor
+      vector match {
+        case dv: DenseVector =>
+          val values = dv.values.clone()
+          val size = values.size
+          var i = 0
+          while(i < size) {
+            values(i) *= localFactor(i)
+            i += 1
+          }
+          Vectors.dense(values)
+        case sv: SparseVector =>
           // For sparse vector, the `index` array inside sparse vector object will not be changed,
           // so we can re-use it to save memory.
-          val output = new BSV[Double](sv.index, sv.data.clone(), sv.length)
+          val indices = sv.indices
+          val values = sv.values.clone()
+          val nnz = values.size
           var i = 0
-          while (i < output.data.length) {
-            output.data(i) *= factor(output.index(i))
+          while (i < nnz) {
+            values(i) *= localFactor(indices(i))
             i += 1
           }
-          Vectors.fromBreeze(output)
+          Vectors.sparse(sv.size, indices, values)
         case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
       }
     } else {
```
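For context, here is a minimal usage sketch of the API this diff touches. It assumes a LibSVM copy of mnist8m at a placeholder path, and the timing is illustrative rather than the benchmark harness behind the numbers in the commit message:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.feature.StandardScaler
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLUtils

// Sketch: fit a StandardScalerModel and exercise the transformation path
// changed by this commit. The dataset path is a placeholder.
object ScalerUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[*]", "scaler-sketch")

    // Load mnist8m and densify, since the withMean branch only handles DenseVector.
    val dense = MLUtils.loadLibSVMFile(sc, "/path/to/mnist8m")
      .map(p => Vectors.dense(p.features.toArray))
      .cache()

    // withMean = true subtracts the mean; withStd = true scales to unit variance.
    val model = new StandardScaler(withMean = true, withStd = true).fit(dense)

    val start = System.nanoTime()
    // Each row goes through StandardScalerModel.transform(Vector) shown above.
    model.transform(dense).count()
    println(s"transform took ${(System.nanoTime() - start) / 1e9} secs")

    sc.stop()
  }
}
```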
