Skip to content

Commit a0e0054

Browse files
committed
update StandardScaler to use SimpleTransformer
1 parent d0faa04 commit a0e0054

File tree

4 files changed

+34
-23
lines changed

4 files changed

+34
-23
lines changed

examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@ package org.apache.spark.examples.ml
1919

2020
import scala.beans.BeanInfo
2121

22+
import org.apache.spark.{SparkConf, SparkContext}
23+
import org.apache.spark.ml.{Pipeline, SimpleTransformer}
2224
import org.apache.spark.ml.classification.LogisticRegression
2325
import org.apache.spark.ml.feature.HashingTF
24-
import org.apache.spark.ml.{Pipeline, SimpleTransformer}
26+
import org.apache.spark.ml.param.ParamMap
2527
import org.apache.spark.sql.SQLContext
26-
import org.apache.spark.{SparkConf, SparkContext}
2728

2829
@BeanInfo
2930
case class LabeledDocument(id: Long, text: String, label: Double)
@@ -36,7 +37,8 @@ case class Document(id: Long, text: String)
3637
*/
3738
class SimpleTokenizer extends SimpleTransformer[String, Seq[String], SimpleTokenizer]
3839
with Serializable {
39-
override def createTransformFunc: String => Seq[String] = _.toLowerCase.split("\\s")
40+
override def createTransformFunc(paramMap: ParamMap): String => Seq[String] =
41+
_.toLowerCase.split("\\s")
4042
}
4143

4244
/**

mllib/src/main/scala/org/apache/spark/ml/Transformer.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,17 @@ abstract class SimpleTransformer[IN, OUT: TypeTag, SELF <: SimpleTransformer[IN,
7474
def setInputCol(value: String): SELF = { set(inputCol, value); this.asInstanceOf[SELF] }
7575
def setOutputCol(value: String): SELF = { set(outputCol, value); this.asInstanceOf[SELF] }
7676

77-
def createTransformFunc: IN => OUT
77+
/**
78+
* Creates the transform function using the given param map. The input param map already takes
79+
* the embedded param map into account, so the param values should be determined solely by the
80+
* input param map.
81+
*/
82+
protected def createTransformFunc(paramMap: ParamMap): IN => OUT
7883

7984
override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
8085
import dataset.sqlContext._
8186
val map = this.paramMap ++ paramMap
82-
val udf: IN => OUT = this.createTransformFunc
87+
val udf: IN => OUT = this.createTransformFunc(map)
8388
dataset.select(Star(None), udf.call(map(inputCol).attr) as map(outputCol))
8489
}
8590
}
mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,18 @@
11
package org.apache.spark.ml.feature
22

3-
import org.apache.spark.ml.Transformer
4-
import org.apache.spark.ml.param.{HasInputCol, HasOutputCol, IntParam, ParamMap}
3+
import org.apache.spark.ml.SimpleTransformer
4+
import org.apache.spark.ml.param.{IntParam, ParamMap}
55
import org.apache.spark.mllib.feature
66
import org.apache.spark.mllib.linalg.Vector
7-
import org.apache.spark.sql.SchemaRDD
8-
import org.apache.spark.sql.catalyst.analysis.Star
9-
import org.apache.spark.sql.catalyst.dsl._
107

11-
class HashingTF extends Transformer with HasInputCol with HasOutputCol {
12-
13-
def setInputCol(value: String) = { set(inputCol, value); this }
14-
def setOutputCol(value: String) = { set(outputCol, value); this }
8+
class HashingTF extends SimpleTransformer[Iterable[_], Vector, HashingTF] {
159

1610
val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18))
1711
def setNumFeatures(value: Int) = { set(numFeatures, value); this }
1812
def getNumFeatures: Int = get(numFeatures)
1913

20-
override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
21-
import dataset.sqlContext._
22-
val map = this.paramMap ++ paramMap
23-
val hashingTF = new feature.HashingTF(map(numFeatures))
24-
val t: Iterable[_] => Vector = (doc) => {
25-
hashingTF.transform(doc)
26-
}
27-
dataset.select(Star(None), t.call(map(inputCol).attr) as map(outputCol))
14+
override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = {
15+
val hashingTF = new feature.HashingTF(paramMap(numFeatures))
16+
hashingTF.transform
2817
}
2918
}

mllib/src/main/scala/org/apache/spark/ml/package.scala

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,19 @@
1-
// to fail jenkins
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
217

318
package org.apache.spark
419

0 commit comments

Comments
 (0)