Skip to content

Commit 1c182dd

Browse files
committed
SPARK-5888. [MLLIB]. Add OneHotEncoder as a Transformer
1 parent f32e69e commit 1c182dd

File tree

2 files changed

+125
-0
lines changed

2 files changed

+125
-0
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.feature
19+
20+
import org.apache.spark.annotation.AlphaComponent
21+
import org.apache.spark.ml.Transformer
22+
import org.apache.spark.ml.attribute.NominalAttribute
23+
import org.apache.spark.ml.param._
24+
import org.apache.spark.sql.DataFrame
25+
import org.apache.spark.sql.functions._
26+
import org.apache.spark.sql.types.{StringType, StructType}
27+
28+
@AlphaComponent
29+
class OneHotEncoder(labelNames: Seq[String], includeFirst: Boolean = true) extends Transformer
30+
with HasInputCol {
31+
32+
/** @group setParam */
33+
def setInputCol(value: String): this.type = set(inputCol, value)
34+
35+
private def outputColName(index: Int): String = {
36+
s"${get(inputCol)}_${labelNames(index)}"
37+
}
38+
39+
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
40+
val map = this.paramMap ++ paramMap
41+
42+
val startIndex = if (includeFirst) 0 else 1
43+
val cols = (startIndex until labelNames.length).map { index =>
44+
val colEncoder = udf { label: Double => if (index == label) 1.0 else 0.0 }
45+
colEncoder(dataset(map(inputCol))).as(outputColName(index))
46+
}
47+
48+
dataset.select(Array(col("*")) ++ cols: _*)
49+
}
50+
51+
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
52+
val map = this.paramMap ++ paramMap
53+
checkInputColumn(schema, map(inputCol), StringType)
54+
val inputFields = schema.fields
55+
val startIndex = if (includeFirst) 0 else 1
56+
val fields = (startIndex until labelNames.length).map { index =>
57+
val colName = outputColName(index)
58+
require(inputFields.forall(_.name != colName),
59+
s"Output column $colName already exists.")
60+
NominalAttribute.defaultAttr.withName(colName).toStructField()
61+
}
62+
63+
val outputFields = inputFields ++ fields
64+
StructType(outputFields)
65+
}
66+
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.feature
19+
20+
import org.apache.spark.mllib.util.MLlibTestSparkContext
21+
22+
import org.scalatest.FunSuite
23+
import org.apache.spark.sql.SQLContext
24+
import org.apache.spark.ml.attribute.{NominalAttribute, Attribute}
25+
26+
class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
27+
private var sqlContext: SQLContext = _
28+
29+
override def beforeAll(): Unit = {
30+
super.beforeAll()
31+
sqlContext = new SQLContext(sc)
32+
}
33+
34+
test("OneHotEncoder") {
35+
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
36+
val df = sqlContext.createDataFrame(data).toDF("id", "label")
37+
val indexer = new StringIndexer()
38+
.setInputCol("label")
39+
.setOutputCol("labelIndex")
40+
.fit(df)
41+
val transformed = indexer.transform(df)
42+
val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
43+
.asInstanceOf[NominalAttribute]
44+
assert(attr.values.get === Array("a", "c", "b"))
45+
46+
val encoder = new OneHotEncoder(attr.values.get)
47+
.setInputCol("labelIndex")
48+
val encoded = encoder.transform(transformed)
49+
50+
val output = encoded.select("id", "labelIndex_a", "labelIndex_c", "labelIndex_b").map { r =>
51+
(r.getInt(0), r.getDouble(1), r.getDouble(2), r.getDouble(3))
52+
}.collect().toSet
53+
// a -> 0, b -> 2, c -> 1
54+
val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0),
55+
(3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0))
56+
assert(output === expected)
57+
}
58+
59+
}

0 commit comments

Comments
 (0)