Commit b743664

viirya authored and Nick Pentreath committed
[SPARK-23048][ML] Add OneHotEncoderEstimator document and examples
## What changes were proposed in this pull request?

`OneHotEncoderEstimator` is now available, and `OneHotEncoder` is deprecated as of 2.3.0. This change adds `OneHotEncoderEstimator` to the MLlib documentation, together with the corresponding examples that the documentation includes.

## How was this patch tested?

Existing tests.

Author: Liang-Chi Hsieh <[email protected]>

Closes #20257 from viirya/SPARK-23048.
1 parent 60203fc commit b743664

4 files changed, with 68 additions and 70 deletions
docs/ml-features.md

Lines changed: 18 additions & 10 deletions
@@ -775,35 +775,43 @@ for more details on the API.
 </div>
 </div>

-## OneHotEncoder
+## OneHotEncoder (Deprecated since 2.3.0)

-[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features.
+Because this existing `OneHotEncoder` is a stateless transformer, it is not usable on new data where the number of categories may differ from the training data. In order to fix this, a new `OneHotEncoderEstimator` was created that produces a `OneHotEncoderModel` when fitting. For more detail, please see [SPARK-13030](https://issues.apache.org/jira/browse/SPARK-13030).
+
+`OneHotEncoder` has been deprecated in 2.3.0 and will be removed in 3.0.0. Please use [OneHotEncoderEstimator](ml-features.html#onehotencoderestimator) instead.
+
+## OneHotEncoderEstimator
+
+[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a categorical feature, represented as a label index, to a binary vector with at most a single one-value indicating the presence of a specific feature value from among the set of all feature values. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features. For string type input data, it is common to encode categorical features using [StringIndexer](ml-features.html#stringindexer) first.
+
+`OneHotEncoderEstimator` can transform multiple columns, returning a one-hot-encoded output vector column for each input column. It is common to merge these vectors into a single feature vector using [VectorAssembler](ml-features.html#vectorassembler).
+
+`OneHotEncoderEstimator` supports the `handleInvalid` parameter to choose how to handle invalid input during transforming data. Available options include 'keep' (any invalid inputs are assigned to an extra categorical index) and 'error' (throw an error).

 **Examples**

 <div class="codetabs">
 <div data-lang="scala" markdown="1">

-Refer to the [OneHotEncoder Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoder)
-for more details on the API.
+Refer to the [OneHotEncoderEstimator Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoderEstimator) for more details on the API.

-{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala %}
+{% include_example scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala %}
 </div>

 <div data-lang="java" markdown="1">

-Refer to the [OneHotEncoder Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoder.html)
+Refer to the [OneHotEncoderEstimator Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoderEstimator.html)
 for more details on the API.

-{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java %}
+{% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java %}
 </div>

 <div data-lang="python" markdown="1">

-Refer to the [OneHotEncoder Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoder)
-for more details on the API.
+Refer to the [OneHotEncoderEstimator Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoderEstimator) for more details on the API.

-{% include_example python/ml/onehot_encoder_example.py %}
+{% include_example python/ml/onehot_encoder_estimator_example.py %}
 </div>
 </div>

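The fit/transform split and the `handleInvalid` options described in the documentation change above can be illustrated without Spark. The following is a minimal pure-Python sketch, not Spark's implementation: the names `OneHotSketchModel` and `fit_sketch` are hypothetical, and the `drop_last` flag is an assumption added here to model the "at most a single one-value" wording (an all-zero vector for the last category).

```python
from typing import Dict, List, Sequence, Tuple


class OneHotSketchModel:
    """Toy stand-in for a fitted encoder: stores one category count per column."""

    def __init__(self, category_sizes: Dict[str, int],
                 handle_invalid: str = "error", drop_last: bool = True) -> None:
        if handle_invalid not in ("error", "keep"):
            raise ValueError("handle_invalid must be 'error' or 'keep'")
        self.category_sizes = category_sizes
        self.handle_invalid = handle_invalid
        self.drop_last = drop_last  # assumption: last category becomes all zeros

    def encode(self, column: str, value: float) -> List[float]:
        size = self.category_sizes[column]
        index = int(value)
        if index < 0 or index >= size:
            if self.handle_invalid == "error":
                raise ValueError(f"invalid category index {value} in {column!r}")
            index = size  # 'keep': route invalid input to one extra category
        length = size + (1 if self.handle_invalid == "keep" else 0)
        if self.drop_last:
            length -= 1
        vec = [0.0] * length
        if index < length:  # a dropped last category stays all zeros
            vec[index] = 1.0
        return vec


def fit_sketch(rows: Sequence[Tuple[float, ...]], columns: Sequence[str],
               handle_invalid: str = "error",
               drop_last: bool = True) -> OneHotSketchModel:
    """Learn the number of categories per column as (largest index + 1)."""
    sizes = {c: int(max(r[i] for r in rows)) + 1 for i, c in enumerate(columns)}
    return OneHotSketchModel(sizes, handle_invalid, drop_last)


# Same toy data as the examples in this commit.
rows = [(0.0, 1.0), (1.0, 0.0), (2.0, 1.0), (0.0, 2.0), (0.0, 1.0), (2.0, 0.0)]
columns = ["categoryIndex1", "categoryIndex2"]

strict = fit_sketch(rows, columns)  # handle_invalid="error"
lenient = fit_sketch(rows, columns, handle_invalid="keep", drop_last=False)
```

With this sketch, `strict.encode("categoryIndex1", 3.0)` raises a `ValueError`, while `lenient.encode("categoryIndex1", 3.0)` places the unseen index in a fourth, extra slot. The learned sizes live in the model, so later data is encoded with the same vector width as the training data.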
examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java renamed to examples/src/main/java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java

Lines changed: 18 additions & 23 deletions
@@ -23,9 +23,8 @@
 import java.util.Arrays;
 import java.util.List;

-import org.apache.spark.ml.feature.OneHotEncoder;
-import org.apache.spark.ml.feature.StringIndexer;
-import org.apache.spark.ml.feature.StringIndexerModel;
+import org.apache.spark.ml.feature.OneHotEncoderEstimator;
+import org.apache.spark.ml.feature.OneHotEncoderModel;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
@@ -35,41 +34,37 @@
 import org.apache.spark.sql.types.StructType;
 // $example off$

-public class JavaOneHotEncoderExample {
+public class JavaOneHotEncoderEstimatorExample {
   public static void main(String[] args) {
     SparkSession spark = SparkSession
       .builder()
-      .appName("JavaOneHotEncoderExample")
+      .appName("JavaOneHotEncoderEstimatorExample")
       .getOrCreate();

+    // Note: categorical features are usually first encoded with StringIndexer
     // $example on$
     List<Row> data = Arrays.asList(
-      RowFactory.create(0, "a"),
-      RowFactory.create(1, "b"),
-      RowFactory.create(2, "c"),
-      RowFactory.create(3, "a"),
-      RowFactory.create(4, "a"),
-      RowFactory.create(5, "c")
+      RowFactory.create(0.0, 1.0),
+      RowFactory.create(1.0, 0.0),
+      RowFactory.create(2.0, 1.0),
+      RowFactory.create(0.0, 2.0),
+      RowFactory.create(0.0, 1.0),
+      RowFactory.create(2.0, 0.0)
     );

     StructType schema = new StructType(new StructField[]{
-      new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
-      new StructField("category", DataTypes.StringType, false, Metadata.empty())
+      new StructField("categoryIndex1", DataTypes.DoubleType, false, Metadata.empty()),
+      new StructField("categoryIndex2", DataTypes.DoubleType, false, Metadata.empty())
     });

     Dataset<Row> df = spark.createDataFrame(data, schema);

-    StringIndexerModel indexer = new StringIndexer()
-      .setInputCol("category")
-      .setOutputCol("categoryIndex")
-      .fit(df);
-    Dataset<Row> indexed = indexer.transform(df);
+    OneHotEncoderEstimator encoder = new OneHotEncoderEstimator()
+      .setInputCols(new String[] {"categoryIndex1", "categoryIndex2"})
+      .setOutputCols(new String[] {"categoryVec1", "categoryVec2"});

-    OneHotEncoder encoder = new OneHotEncoder()
-      .setInputCol("categoryIndex")
-      .setOutputCol("categoryVec");
-
-    Dataset<Row> encoded = encoder.transform(indexed);
+    OneHotEncoderModel model = encoder.fit(df);
+    Dataset<Row> encoded = model.transform(df);
     encoded.show();
     // $example off$

examples/src/main/python/ml/onehot_encoder_example.py renamed to examples/src/main/python/ml/onehot_encoder_estimator_example.py

Lines changed: 14 additions & 15 deletions
@@ -18,32 +18,31 @@
 from __future__ import print_function

 # $example on$
-from pyspark.ml.feature import OneHotEncoder, StringIndexer
+from pyspark.ml.feature import OneHotEncoderEstimator
 # $example off$
 from pyspark.sql import SparkSession

 if __name__ == "__main__":
     spark = SparkSession\
         .builder\
-        .appName("OneHotEncoderExample")\
+        .appName("OneHotEncoderEstimatorExample")\
         .getOrCreate()

+    # Note: categorical features are usually first encoded with StringIndexer
     # $example on$
     df = spark.createDataFrame([
-        (0, "a"),
-        (1, "b"),
-        (2, "c"),
-        (3, "a"),
-        (4, "a"),
-        (5, "c")
-    ], ["id", "category"])
+        (0.0, 1.0),
+        (1.0, 0.0),
+        (2.0, 1.0),
+        (0.0, 2.0),
+        (0.0, 1.0),
+        (2.0, 0.0)
+    ], ["categoryIndex1", "categoryIndex2"])

-    stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex")
-    model = stringIndexer.fit(df)
-    indexed = model.transform(df)
-
-    encoder = OneHotEncoder(inputCol="categoryIndex", outputCol="categoryVec")
-    encoded = encoder.transform(indexed)
+    encoder = OneHotEncoderEstimator(inputCols=["categoryIndex1", "categoryIndex2"],
+                                     outputCols=["categoryVec1", "categoryVec2"])
+    model = encoder.fit(df)
+    encoded = model.transform(df)
     encoded.show()
     # $example off$

examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala renamed to examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala

Lines changed: 18 additions & 22 deletions
@@ -19,38 +19,34 @@
 package org.apache.spark.examples.ml

 // $example on$
-import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
+import org.apache.spark.ml.feature.OneHotEncoderEstimator
 // $example off$
 import org.apache.spark.sql.SparkSession

-object OneHotEncoderExample {
+object OneHotEncoderEstimatorExample {
   def main(args: Array[String]): Unit = {
     val spark = SparkSession
       .builder
-      .appName("OneHotEncoderExample")
+      .appName("OneHotEncoderEstimatorExample")
       .getOrCreate()

+    // Note: categorical features are usually first encoded with StringIndexer
     // $example on$
     val df = spark.createDataFrame(Seq(
-      (0, "a"),
-      (1, "b"),
-      (2, "c"),
-      (3, "a"),
-      (4, "a"),
-      (5, "c")
-    )).toDF("id", "category")
-
-    val indexer = new StringIndexer()
-      .setInputCol("category")
-      .setOutputCol("categoryIndex")
-      .fit(df)
-    val indexed = indexer.transform(df)
-
-    val encoder = new OneHotEncoder()
-      .setInputCol("categoryIndex")
-      .setOutputCol("categoryVec")
-
-    val encoded = encoder.transform(indexed)
+      (0.0, 1.0),
+      (1.0, 0.0),
+      (2.0, 1.0),
+      (0.0, 2.0),
+      (0.0, 1.0),
+      (2.0, 0.0)
+    )).toDF("categoryIndex1", "categoryIndex2")
+
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("categoryIndex1", "categoryIndex2"))
+      .setOutputCols(Array("categoryVec1", "categoryVec2"))
+    val model = encoder.fit(df)
+
+    val encoded = model.transform(df)
     encoded.show()
     // $example off$

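All three examples now call `fit` before `transform`. The reason given in the documentation change is that a stateless encoder cannot be relied on when new data contains a different number of categories than the training data. The sketch below illustrates that problem in plain Python under simplifying assumptions: it is not Spark code, the helper names are hypothetical, and it ignores details such as dropping the last category.

```python
from typing import List, Sequence


def one_hot(index: float, size: int) -> List[float]:
    """Return a length-`size` vector with a single 1.0 at position `index`."""
    return [1.0 if i == int(index) else 0.0 for i in range(size)]


def stateless_encode(indices: Sequence[float]) -> List[List[float]]:
    """No fitted state: the vector size is inferred from the batch at hand."""
    size = int(max(indices)) + 1
    return [one_hot(v, size) for v in indices]


def fit_size(indices: Sequence[float]) -> int:
    """'Fit' step: remember how many categories the training data contains."""
    return int(max(indices)) + 1


def fitted_encode(indices: Sequence[float], size: int) -> List[List[float]]:
    """'Transform' step: always use the size learned during fitting."""
    return [one_hot(v, size) for v in indices]


training = [0.0, 1.0, 2.0, 0.0, 0.0, 2.0]  # categoryIndex1 from the examples
new_data = [0.0, 1.0]  # hypothetical later batch in which category 2 never occurs

stateless_train_width = len(stateless_encode(training)[0])  # 3
stateless_new_width = len(stateless_encode(new_data)[0])    # 2, no longer matches

size = fit_size(training)
fitted_train_width = len(fitted_encode(training, size)[0])  # 3
fitted_new_width = len(fitted_encode(new_data, size)[0])    # 3, consistent
```

A downstream model trained on width-3 vectors cannot consume the width-2 vectors that the stateless path produces for the later batch. Storing the size at fit time, as the estimator/model pair does, avoids the mismatch.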