[SPARK-23048][ML] Add OneHotEncoderEstimator document and examples #20257
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -775,7 +775,9 @@ for more details on the API. | |
| </div> | ||
| </div> | ||
|
|
||
| ## OneHotEncoder | ||
| ## OneHotEncoder (Deprecated since 2.3.0) | ||
|
|
||
| `OneHotEncoder` is deprecated as of 2.3.0 and will be removed in 3.0.0. Please use [OneHotEncoderEstimator](ml-features.html#onehotencoderestimator) instead. | ||
|
||
|
|
||
| [One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features. | ||
|
|
||
|
|
@@ -807,6 +809,36 @@ for more details on the API. | |
| </div> | ||
| </div> | ||
|
|
||
| ## OneHotEncoderEstimator | ||
|
|
||
| [One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features. | ||
|
||
|
|
||
| **Examples** | ||
|
|
||
| <div class="codetabs"> | ||
| <div data-lang="scala" markdown="1"> | ||
|
|
||
| Refer to the [OneHotEncoderEstimator Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoderEstimator) for more details on the API. | ||
|
|
||
| {% include_example scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala %} | ||
| </div> | ||
|
|
||
| <div data-lang="java" markdown="1"> | ||
|
|
||
| Refer to the [OneHotEncoderEstimator Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoderEstimator.html) | ||
| for more details on the API. | ||
|
|
||
| {% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java %} | ||
| </div> | ||
|
|
||
| <div data-lang="python" markdown="1"> | ||
|
|
||
| Refer to the [OneHotEncoderEstimator Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoderEstimator) for more details on the API. | ||
|
|
||
| {% include_example python/ml/onehot_encoder_estimator_example.py %} | ||
| </div> | ||
| </div> | ||
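To make the "at most a single one-value" wording in the OneHotEncoderEstimator description concrete, here is a minimal plain-Python sketch of the mapping from a label index to a binary vector. It does not use Spark; `one_hot` and `drop_last` are hypothetical names chosen for illustration, and the drop-last behavior (the last category encodes as an all-zero vector) is an assumption modelled on the encoder's `dropLast` option rather than a copy of its implementation.

```python
def one_hot(index, num_categories, drop_last=True):
    """Encode a label index as a dense binary vector (illustration only, not the Spark API)."""
    if not 0 <= index < num_categories:
        raise ValueError("index %d out of range for %d categories" % (index, num_categories))
    # With drop_last, the vector has one slot fewer and the last category is all zeros.
    size = num_categories - 1 if drop_last else num_categories
    vec = [0.0] * size
    if index < size:
        vec[index] = 1.0
    return vec


if __name__ == "__main__":
    for i in range(3):
        print(i, one_hot(i, 3))
```

With three categories and `drop_last=True`, indices 0 and 1 each get a single 1.0 and index 2 gets no 1.0 at all, which is why the description says "at most" one.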
|
|
||
| ## VectorIndexer | ||
|
|
||
| `VectorIndexer` helps index categorical features in datasets of `Vector`s. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.examples.ml; | ||
|
|
||
| import org.apache.spark.sql.SparkSession; | ||
|
|
||
| // $example on$ | ||
| import java.util.Arrays; | ||
| import java.util.List; | ||
|
|
||
| import org.apache.spark.ml.feature.OneHotEncoderEstimator; | ||
| import org.apache.spark.ml.feature.OneHotEncoderModel; | ||
| import org.apache.spark.ml.feature.StringIndexer; | ||
| import org.apache.spark.ml.feature.StringIndexerModel; | ||
| import org.apache.spark.sql.Dataset; | ||
| import org.apache.spark.sql.Row; | ||
| import org.apache.spark.sql.RowFactory; | ||
| import org.apache.spark.sql.types.DataTypes; | ||
| import org.apache.spark.sql.types.Metadata; | ||
| import org.apache.spark.sql.types.StructField; | ||
| import org.apache.spark.sql.types.StructType; | ||
| // $example off$ | ||
|
|
||
| public class JavaOneHotEncoderEstimatorExample { | ||
| public static void main(String[] args) { | ||
| SparkSession spark = SparkSession | ||
| .builder() | ||
| .appName("JavaOneHotEncoderEstimatorExample") | ||
| .getOrCreate(); | ||
|
|
||
| // $example on$ | ||
| List<Row> data = Arrays.asList( | ||
| RowFactory.create(0, "a", "x"), | ||
| RowFactory.create(1, "b", "y"), | ||
| RowFactory.create(2, "c", "y"), | ||
| RowFactory.create(3, "a", "z"), | ||
| RowFactory.create(4, "a", "y"), | ||
| RowFactory.create(5, "c", "z") | ||
| ); | ||
|
|
||
| StructType schema = new StructType(new StructField[]{ | ||
| new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), | ||
| new StructField("category1", DataTypes.StringType, false, Metadata.empty()), | ||
| new StructField("category2", DataTypes.StringType, false, Metadata.empty()) | ||
| }); | ||
|
|
||
| Dataset<Row> df = spark.createDataFrame(data, schema); | ||
|
|
||
| // TODO: Replace this with multi-column API of StringIndexer once SPARK-11215 is merged. | ||
| StringIndexerModel indexer1 = new StringIndexer() | ||
| .setInputCol("category1") | ||
| .setOutputCol("categoryIndex1") | ||
| .fit(df); | ||
| StringIndexerModel indexer2 = new StringIndexer() | ||
| .setInputCol("category2") | ||
| .setOutputCol("categoryIndex2") | ||
| .fit(df); | ||
| Dataset<Row> indexed1 = indexer1.transform(df); | ||
| Dataset<Row> indexed2 = indexer2.transform(indexed1); | ||
|
|
||
| OneHotEncoderEstimator encoder = new OneHotEncoderEstimator() | ||
| .setInputCols(new String[] {"categoryIndex1", "categoryIndex2"}) | ||
| .setOutputCols(new String[] {"categoryVec1", "categoryVec2"}); | ||
|
|
||
| OneHotEncoderModel model = encoder.fit(indexed2); | ||
| Dataset<Row> encoded = model.transform(indexed2); | ||
| encoded.show(); | ||
| // $example off$ | ||
|
|
||
| spark.stop(); | ||
| } | ||
| } | ||
|
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| # | ||
| # Licensed to the Apache Software Foundation (ASF) under one or more | ||
| # contributor license agreements. See the NOTICE file distributed with | ||
| # this work for additional information regarding copyright ownership. | ||
| # The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| # (the "License"); you may not use this file except in compliance with | ||
| # the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # | ||
|
|
||
| from __future__ import print_function | ||
|
|
||
| # $example on$ | ||
| from pyspark.ml.feature import OneHotEncoderEstimator, StringIndexer | ||
| # $example off$ | ||
| from pyspark.sql import SparkSession | ||
|
|
||
| if __name__ == "__main__": | ||
|     spark = SparkSession\ | ||
|         .builder\ | ||
|         .appName("OneHotEncoderEstimatorExample")\ | ||
|         .getOrCreate() | ||
|
|
||
|     # $example on$ | ||
|     df = spark.createDataFrame([ | ||
|         (0, "a", "x"), | ||
|         (1, "b", "y"), | ||
|         (2, "c", "y"), | ||
|         (3, "a", "z"), | ||
|         (4, "a", "y"), | ||
|         (5, "c", "z") | ||
|     ], ["id", "category1", "category2"]) | ||
|
|
||
|     # TODO: Replace this with multi-column API of StringIndexer once SPARK-11215 is merged. | ||
|     stringIndexer1 = StringIndexer(inputCol="category1", outputCol="categoryIndex1") | ||
|     stringIndexer2 = StringIndexer(inputCol="category2", outputCol="categoryIndex2") | ||
|
|
||
|     stringIndexerModel1 = stringIndexer1.fit(df) | ||
|     indexed1 = stringIndexerModel1.transform(df) | ||
|     stringIndexerModel2 = stringIndexer2.fit(indexed1) | ||
|     indexed2 = stringIndexerModel2.transform(indexed1) | ||
|
|
||
|     encoder = OneHotEncoderEstimator(inputCols=["categoryIndex1", "categoryIndex2"], | ||
|                                      outputCols=["categoryVec1", "categoryVec2"]) | ||
|     model = encoder.fit(indexed2) | ||
|     encoded = model.transform(indexed2) | ||
|     encoded.show() | ||
|     # $example off$ | ||
|
|
||
|     spark.stop() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| // scalastyle:off println | ||
| package org.apache.spark.examples.ml | ||
|
|
||
| // $example on$ | ||
| import org.apache.spark.ml.feature.{OneHotEncoderEstimator, StringIndexer} | ||
| // $example off$ | ||
| import org.apache.spark.sql.SparkSession | ||
|
|
||
| object OneHotEncoderEstimatorExample { | ||
| def main(args: Array[String]): Unit = { | ||
| val spark = SparkSession | ||
| .builder | ||
| .appName("OneHotEncoderEstimatorExample") | ||
| .getOrCreate() | ||
|
|
||
| // $example on$ | ||
| val df = spark.createDataFrame(Seq( | ||
|
Contributor
I know the examples are re-creating the existing We could mention in the user guide that it is common to encode categorical features using
Member
Author
Ok for me. As an example, it seems a bit lengthy because the two |
||
| (0, "a", "x"), | ||
| (1, "b", "y"), | ||
| (2, "c", "y"), | ||
| (3, "a", "z"), | ||
| (4, "a", "y"), | ||
| (5, "c", "z") | ||
| )).toDF("id", "category1", "category2") | ||
|
|
||
| // TODO: Replace this with multi-column API of StringIndexer once SPARK-11215 is merged. | ||
| val indexer1 = new StringIndexer() | ||
| .setInputCol("category1") | ||
| .setOutputCol("categoryIndex1") | ||
| .fit(df) | ||
| val indexer2 = new StringIndexer() | ||
| .setInputCol("category2") | ||
| .setOutputCol("categoryIndex2") | ||
| .fit(df) | ||
| val indexed1 = indexer1.transform(df) | ||
| val indexed2 = indexer2.transform(indexed1) | ||
|
|
||
| val encoder = new OneHotEncoderEstimator() | ||
| .setInputCols(Array("categoryIndex1", "categoryIndex2")) | ||
| .setOutputCols(Array("categoryVec1", "categoryVec2")) | ||
| val model = encoder.fit(indexed2) | ||
|
|
||
| val encoded = model.transform(indexed2) | ||
| encoded.show() | ||
| // $example off$ | ||
|
|
||
| spark.stop() | ||
| } | ||
| } | ||
| // scalastyle:on println | ||
I think we should add a little more detail about why it's deprecated.
The reason is that the existing `OneHotEncoder` is a stateless transformer, so it is not usable on new data where the number of categories may differ from the training data. To fix this, a new `OneHotEncoderEstimator` was created that produces a `OneHotEncoderModel` when fit. Add a link to the JIRA ticket for more detail (https://issues.apache.org/jira/browse/SPARK-13030).
Sure. Added.
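
For the deprecation note, the stateless-versus-estimator point raised above can be illustrated with a small plain-Python sketch. This is a simplified model, not Spark code: `stateless_encode` and `FittedEncoder` are hypothetical names, the sketch assumes the stateless variant infers the vector size from whatever batch it is given, and it leaves out the drop-last option for brevity.

```python
def stateless_encode(indices):
    """Stateless sketch: the vector size is inferred from the batch itself."""
    size = max(indices) + 1
    return [[1.0 if j == i else 0.0 for j in range(size)] for i in indices]


class FittedEncoder(object):
    """Estimator-style sketch: the vector size is learned once in fit()."""

    def __init__(self, size):
        self.size = size

    @classmethod
    def fit(cls, indices):
        # Remember the number of categories seen in the training data.
        return cls(max(indices) + 1)

    def transform(self, indices):
        for i in indices:
            if not 0 <= i < self.size:
                # Rejecting unseen indices is a choice made for this sketch.
                raise ValueError("unseen category index: %d" % i)
        return [[1.0 if j == i else 0.0 for j in range(self.size)] for i in indices]


if __name__ == "__main__":
    train, new = [0, 1, 2], [0, 1]
    # Stateless: vector widths differ between training data and new data.
    print(len(stateless_encode(train)[0]), len(stateless_encode(new)[0]))
    # Fitted: the width is fixed at fit time and reused on new data.
    model = FittedEncoder.fit(train)
    print(len(model.transform(train)[0]), len(model.transform(new)[0]))
```

In this model the stateless function yields vectors of different widths for the training batch and the new batch, while the fitted encoder keeps the width learned during `fit`, which is the property a downstream model needs.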