
Commit e07baf1

[SPARK-17001][ML] Enable standardScaler to standardize sparse vectors when withMean=True
## What changes were proposed in this pull request?

Allow centering / mean scaling of sparse vectors in StandardScaler, if requested. This is for compatibility with `VectorAssembler` in common usages.

## How was this patch tested?

Jenkins tests, including new cases to reflect the new behavior.

Author: Sean Owen <[email protected]>

Closes #14663 from srowen/SPARK-17001.
1 parent 9fbced5 commit e07baf1
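In short, `StandardScalerModel.transform` with `withMean = true` no longer rejects `SparseVector` input; it centers it and returns a dense result. A minimal sketch of the new behavior against the RDD-based API, assuming a live `SparkContext` named `sc` (variable names here are illustrative, not from the patch):

```scala
import org.apache.spark.mllib.feature.StandardScaler
import org.apache.spark.mllib.linalg.Vectors

// A mix of sparse and dense rows, as VectorAssembler commonly produces.
val features = sc.parallelize(Seq(
  Vectors.sparse(3, Array(0, 1), Array(-2.0, 2.3)),
  Vectors.sparse(3, Array(1, 2), Array(-5.1, 1.0)),
  Vectors.dense(1.7, -0.6, 3.3)))

// Before this commit, transform threw IllegalArgumentException on the sparse
// rows when withMean = true; now it centers them and emits dense vectors.
val model = new StandardScaler(withMean = true, withStd = false).fit(features)
model.transform(features).collect().foreach(println)
```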

File tree: 9 files changed, +80 -62 lines


docs/ml-features.md

Lines changed: 1 addition & 1 deletion
@@ -768,7 +768,7 @@ for more details on the API.
 `StandardScaler` transforms a dataset of `Vector` rows, normalizing each feature to have unit standard deviation and/or zero mean. It takes parameters:
 
 * `withStd`: True by default. Scales the data to unit standard deviation.
-* `withMean`: False by default. Centers the data with mean before scaling. It will build a dense output, so this does not work on sparse input and will raise an exception.
+* `withMean`: False by default. Centers the data with mean before scaling. It will build a dense output, so take care when applying to sparse input.
 
 `StandardScaler` is an `Estimator` which can be `fit` on a dataset to produce a `StandardScalerModel`; this amounts to computing summary statistics. The model can then transform a `Vector` column in a dataset to have unit standard deviation and/or zero mean features.
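For the DataFrame-based API the same now holds; a hedged sketch, assuming a `SparkSession` named `spark` (the column names are illustrative):

```scala
import org.apache.spark.ml.feature.StandardScaler
import org.apache.spark.ml.linalg.Vectors

val df = spark.createDataFrame(Seq(
  Tuple1(Vectors.sparse(3, Array(0, 2), Array(1.0, 4.0))),
  Tuple1(Vectors.dense(2.0, 3.0, 5.0))
)).toDF("features")

val scalerModel = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaled")
  .setWithMean(true)  // no longer raises on sparse rows; output rows are dense
  .fit(df)

scalerModel.transform(df).show(false)
```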

docs/mllib-feature-extraction.md

Lines changed: 1 addition & 1 deletion
@@ -148,7 +148,7 @@ against features with very large variances exerting an overly large influence du
 following parameters in the constructor:
 
 * `withMean` False by default. Centers the data with mean before scaling. It will build a dense
-  output, so this does not work on sparse input and will raise an exception.
+  output, so take care when applying to sparse input.
 * `withStd` True by default. Scales the data to unit standard deviation.
 
 We provide a [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScaler) method in
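The "take care" caveat is about memory, not correctness: centering turns a sparse vector into a dense one of the full feature dimension. A small sketch, constructing a `StandardScalerModel` directly from made-up statistics (the constructor argument order `(std, mean, withStd, withMean)` matches the tests later in this commit):

```scala
import org.apache.spark.mllib.feature.StandardScalerModel
import org.apache.spark.mllib.linalg.{DenseVector, Vectors}

val n = 100000
// Hypothetical statistics; in practice these come from StandardScaler.fit.
val model = new StandardScalerModel(
  Vectors.dense(Array.fill(n)(1.0)),   // std (unused when withStd = false)
  Vectors.dense(Array.fill(n)(0.5)),   // mean
  false,                               // withStd
  true)                                // withMean

val sparse = Vectors.sparse(n, Array(3, 17), Array(1.0, 2.0))
val centered = model.transform(sparse)
// O(nnz) storage in, O(n) storage out: centering fills in every coordinate.
assert(centered.isInstanceOf[DenseVector] && centered.size == n)
```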

examples/src/main/python/mllib/standard_scaler_example.py

Lines changed: 0 additions & 2 deletions
@@ -38,8 +38,6 @@
 # data1 will be unit variance.
 data1 = label.zip(scaler1.transform(features))
 
-# Without converting the features into dense vectors, transformation with zero mean will raise
-# exception on sparse vector.
 # data2 will be unit variance and zero mean.
 data2 = label.zip(scaler2.transform(features.map(lambda x: Vectors.dense(x.toArray()))))
 # $example off$

examples/src/main/scala/org/apache/spark/examples/mllib/StandardScalerExample.scala

Lines changed: 0 additions & 2 deletions
@@ -44,8 +44,6 @@ object StandardScalerExample {
     // data1 will be unit variance.
     val data1 = data.map(x => (x.label, scaler1.transform(x.features)))
 
-    // Without converting the features into dense vectors, transformation with zero mean will raise
-    // exception on sparse vector.
    // data2 will be unit variance and zero mean.
     val data2 = data.map(x => (x.label, scaler2.transform(Vectors.dense(x.features.toArray))))
     // $example off$
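A side note: both examples keep the explicit `Vectors.dense(...)` conversion, but after this change it is no longer required for correctness. Presumably (this simplification is not part of the patch) the Scala example's line could now read:

```scala
// data2 will be unit variance and zero mean; sparse features now work directly.
val data2 = data.map(x => (x.label, scaler2.transform(x.features)))
```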

mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala

Lines changed: 1 addition & 2 deletions
@@ -41,8 +41,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
 
   /**
    * Whether to center the data with mean before scaling.
-   * It will build a dense output, so this does not work on sparse input
-   * and will raise an exception.
+   * It will build a dense output, so take care when applying to sparse input.
    * Default: false
    * @group param
    */

mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala

Lines changed: 21 additions & 20 deletions
@@ -32,7 +32,7 @@ import org.apache.spark.rdd.RDD
  * which is computed as the square root of the unbiased sample variance.
  *
  * @param withMean False by default. Centers the data with mean before scaling. It will build a
- *                 dense output, so this does not work on sparse input and will raise an exception.
+ *                 dense output, so take care when applying to sparse input.
  * @param withStd True by default. Scales the data to unit standard deviation.
  */
 @Since("1.1.0")
@@ -139,26 +139,27 @@ class StandardScalerModel @Since("1.3.0") (
       // the member variables are accessed, `invokespecial` will be called which is expensive.
       // This can be avoid by having a local reference of `shift`.
       val localShift = shift
-      vector match {
-        case DenseVector(vs) =>
-          val values = vs.clone()
-          val size = values.length
-          if (withStd) {
-            var i = 0
-            while (i < size) {
-              values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * (1.0 / std(i)) else 0.0
-              i += 1
-            }
-          } else {
-            var i = 0
-            while (i < size) {
-              values(i) -= localShift(i)
-              i += 1
-            }
-          }
-          Vectors.dense(values)
-        case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
+      // Must have a copy of the values since it will be modified in place
+      val values = vector match {
+        // specially handle DenseVector because its toArray does not clone already
+        case d: DenseVector => d.values.clone()
+        case v: Vector => v.toArray
+      }
+      val size = values.length
+      if (withStd) {
+        var i = 0
+        while (i < size) {
+          values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * (1.0 / std(i)) else 0.0
+          i += 1
+        }
+      } else {
+        var i = 0
+        while (i < size) {
+          values(i) -= localShift(i)
+          i += 1
+        }
       }
+      Vectors.dense(values)
     } else if (withStd) {
       vector match {
         case DenseVector(vs) =>
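The `clone()` special case exists because `DenseVector.toArray` hands back the vector's backing array rather than a copy, as the new inline comment notes; writing into it in place would silently corrupt the caller's input. A tiny sketch of the aliasing hazard (illustrative, not from the patch):

```scala
import org.apache.spark.mllib.linalg.{DenseVector, Vectors}

val dv = new DenseVector(Array(1.0, 2.0, 3.0))
val alias = dv.toArray    // for DenseVector, this is the backing array itself
alias(0) = -1.0
assert(dv(0) == -1.0)     // mutated through the alias: hence values.clone()

val sv = Vectors.sparse(3, Array(0), Array(1.0))
val fresh = sv.toArray    // SparseVector.toArray materializes a new array
fresh(0) = -1.0
assert(sv(0) == 1.0)      // the original sparse vector is untouched
```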

mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala

Lines changed: 16 additions & 0 deletions
@@ -114,6 +114,22 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
     assertResult(standardScaler3.transform(df3))
   }
 
+  test("sparse data and withMean") {
+    val someSparseData = Array(
+      Vectors.sparse(3, Array(0, 1), Array(-2.0, 2.3)),
+      Vectors.sparse(3, Array(1, 2), Array(-5.1, 1.0)),
+      Vectors.dense(1.7, -0.6, 3.3)
+    )
+    val df = spark.createDataFrame(someSparseData.zip(resWithMean)).toDF("features", "expected")
+    val standardScaler = new StandardScaler()
+      .setInputCol("features")
+      .setOutputCol("standardized_features")
+      .setWithMean(true)
+      .setWithStd(false)
+      .fit(df)
+    assertResult(standardScaler.transform(df))
+  }
+
   test("StandardScaler read/write") {
     val t = new StandardScaler()
       .setInputCol("myInputCol")

mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala

Lines changed: 38 additions & 31 deletions
@@ -207,37 +207,41 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext {
     val equivalentModel2 = new StandardScalerModel(model2.std, model2.mean, true, false)
     val equivalentModel3 = new StandardScalerModel(model3.std, model3.mean, false, true)
 
+    val data1 = sparseData.map(equivalentModel1.transform)
     val data2 = sparseData.map(equivalentModel2.transform)
+    val data3 = sparseData.map(equivalentModel3.transform)
 
-    withClue("Standardization with mean can not be applied on sparse input.") {
-      intercept[IllegalArgumentException] {
-        sparseData.map(equivalentModel1.transform)
-      }
-    }
-
-    withClue("Standardization with mean can not be applied on sparse input.") {
-      intercept[IllegalArgumentException] {
-        sparseData.map(equivalentModel3.transform)
-      }
-    }
-
+    val data1RDD = equivalentModel1.transform(dataRDD)
     val data2RDD = equivalentModel2.transform(dataRDD)
+    val data3RDD = equivalentModel3.transform(dataRDD)
 
-    val summary = computeSummary(data2RDD)
+    val summary1 = computeSummary(data1RDD)
+    val summary2 = computeSummary(data2RDD)
+    val summary3 = computeSummary(data3RDD)
 
     assert((sparseData, data2, data2RDD.collect()).zipped.forall {
       case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
       case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
       case _ => false
     }, "The vector type should be preserved after standardization.")
 
+    assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
     assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+    assert((data3, data3RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
 
-    assert(summary.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
-    assert(summary.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+    assert(summary1.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+    assert(summary1.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+    assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+    assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+    assert(summary3.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+    assert(summary3.variance !~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
 
+    assert(data1(4) ~== Vectors.dense(0.56854, -0.069068, 0.116377) absTol 1E-5)
+    assert(data1(5) ~== Vectors.dense(-0.296998, 0.872775, 0.116377) absTol 1E-5)
     assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5)
     assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5)
+    assert(data3(4) ~== Vectors.dense(1.116666, -0.183333, 0.183333) absTol 1E-5)
+    assert(data3(5) ~== Vectors.dense(-0.583333, 2.316666, 0.183333) absTol 1E-5)
   }
 
   test("Standardization with sparse input") {
@@ -252,38 +256,41 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext {
     val model2 = standardizer2.fit(dataRDD)
     val model3 = standardizer3.fit(dataRDD)
 
+    val data1 = sparseData.map(model1.transform)
     val data2 = sparseData.map(model2.transform)
+    val data3 = sparseData.map(model3.transform)
 
-    withClue("Standardization with mean can not be applied on sparse input.") {
-      intercept[IllegalArgumentException] {
-        sparseData.map(model1.transform)
-      }
-    }
-
-    withClue("Standardization with mean can not be applied on sparse input.") {
-      intercept[IllegalArgumentException] {
-        sparseData.map(model3.transform)
-      }
-    }
-
+    val data1RDD = model1.transform(dataRDD)
     val data2RDD = model2.transform(dataRDD)
+    val data3RDD = model3.transform(dataRDD)
 
-
-    val summary = computeSummary(data2RDD)
+    val summary1 = computeSummary(data1RDD)
+    val summary2 = computeSummary(data2RDD)
+    val summary3 = computeSummary(data3RDD)
 
     assert((sparseData, data2, data2RDD.collect()).zipped.forall {
       case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
      case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
       case _ => false
     }, "The vector type should be preserved after standardization.")
 
+    assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
     assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+    assert((data3, data3RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
 
-    assert(summary.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
-    assert(summary.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+    assert(summary1.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+    assert(summary1.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+    assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+    assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+    assert(summary3.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+    assert(summary3.variance !~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
 
+    assert(data1(4) ~== Vectors.dense(0.56854, -0.069068, 0.116377) absTol 1E-5)
+    assert(data1(5) ~== Vectors.dense(-0.296998, 0.872775, 0.116377) absTol 1E-5)
     assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5)
     assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5)
+    assert(data3(4) ~== Vectors.dense(1.116666, -0.183333, 0.183333) absTol 1E-5)
+    assert(data3(5) ~== Vectors.dense(-0.583333, 2.316666, 0.183333) absTol 1E-5)
   }
 
   test("Standardization with constant input when means and stds are provided") {

python/pyspark/mllib/feature.py

Lines changed: 2 additions & 3 deletions
@@ -208,9 +208,8 @@ class StandardScaler(object):
     training set.
 
     :param withMean: False by default. Centers the data with mean
-      before scaling. It will build a dense output, so this
-      does not work on sparse input and will raise an
-      exception.
+      before scaling. It will build a dense output, so take
+      care when applying to sparse input.
     :param withStd: True by default. Scales the data to unit
       standard deviation.
