[SPARK-23048][ML] Add OneHotEncoderEstimator document and examples #20257
Changes from 5 commits
21cb7d3
13a7b90
262c046
e57d9ee
18cf226
3c697bd
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -775,35 +775,43 @@ for more details on the API. | |
| </div> | ||
| </div> | ||
|
|
||
| ## OneHotEncoder | ||
| ## OneHotEncoder (Deprecated since 2.3.0) | ||
|
|
||
| [One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features. | ||
| Because the existing `OneHotEncoder` is a stateless transformer, it is not usable on new data where the number of categories may differ from the training data. In order to fix this, a new `OneHotEncoderEstimator` was created that produces a `OneHotEncoderModel` when fitting. For more detail, please see [SPARK-13030](https://issues.apache.org/jira/browse/SPARK-13030). | ||
|
|
||
| `OneHotEncoder` has been deprecated in 2.3.0 and will be removed in 3.0.0. Please use [OneHotEncoderEstimator](ml-features.html#onehotencoderestimator) instead. | ||
|
|
||
| ## OneHotEncoderEstimator | ||
|
|
||
| [One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a categorical feature, represented as a label index, to a binary vector with at most a single one-value indicating the presence of a specific feature value from among the set of all feature values. | ||
|
||
|
|
||
| `OneHotEncoderEstimator` can transform multiple columns, returning a one-hot-encoded output vector column for each input column. It is common to merge these vectors into a single feature vector using `VectorAssembler`. | ||
|
||
|
|
||
| `OneHotEncoderEstimator` supports the `handleInvalid` parameter to choose how to handle invalid input when transforming data. Available options include 'keep' (any invalid inputs are assigned to an extra categorical index) and 'error' (throw an error). | ||
|
|
||
| **Examples** | ||
|
|
||
| <div class="codetabs"> | ||
| <div data-lang="scala" markdown="1"> | ||
|
|
||
| Refer to the [OneHotEncoder Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoder) | ||
| for more details on the API. | ||
| Refer to the [OneHotEncoderEstimator Scala docs](api/scala/index.html#org.apache.spark.ml.feature.OneHotEncoderEstimator) for more details on the API. | ||
|
|
||
| {% include_example scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala %} | ||
| {% include_example scala/org/apache/spark/examples/ml/OneHotEncoderEstimatorExample.scala %} | ||
| </div> | ||
|
|
||
| <div data-lang="java" markdown="1"> | ||
|
|
||
| Refer to the [OneHotEncoder Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoder.html) | ||
| Refer to the [OneHotEncoderEstimator Java docs](api/java/org/apache/spark/ml/feature/OneHotEncoderEstimator.html) | ||
| for more details on the API. | ||
|
|
||
| {% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderExample.java %} | ||
| {% include_example java/org/apache/spark/examples/ml/JavaOneHotEncoderEstimatorExample.java %} | ||
| </div> | ||
|
|
||
| <div data-lang="python" markdown="1"> | ||
|
|
||
| Refer to the [OneHotEncoder Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoder) | ||
| for more details on the API. | ||
| Refer to the [OneHotEncoderEstimator Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.OneHotEncoderEstimator) for more details on the API. | ||
|
|
||
| {% include_example python/ml/onehot_encoder_example.py %} | ||
| {% include_example python/ml/onehot_encoder_estimator_example.py %} | ||
| </div> | ||
| </div> | ||
|
|
||
|
|
||
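The fit-then-transform flow and the `handleInvalid` options described in the user guide text above can be illustrated with a small pure-Python sketch. This is not Spark's implementation: the class `OneHotSketch`, its `encode` method, and the `handle_invalid` / `drop_last` arguments are hypothetical names chosen for illustration. It also assumes that the last category is dropped by default, which is one way a vector can end up with "at most" a single one-value, and it uses integer indices where the examples in this PR use doubles.

```python
from typing import Dict, List, Sequence


class OneHotSketch:
    """Toy stand-in for the fit/transform flow; NOT Spark's implementation."""

    def __init__(self, handle_invalid: str = "error", drop_last: bool = True):
        if handle_invalid not in ("error", "keep"):
            raise ValueError("handle_invalid must be 'error' or 'keep'")
        self.handle_invalid = handle_invalid
        self.drop_last = drop_last  # assumption: last category becomes all zeros
        self.sizes: Dict[str, int] = {}

    def fit(self, columns: Dict[str, Sequence[int]]) -> "OneHotSketch":
        # Learn the number of categories per column as (max index + 1).
        self.sizes = {name: max(values) + 1 for name, values in columns.items()}
        return self

    def encode(self, column: str, index: int) -> List[float]:
        size = self.sizes[column]
        if not 0 <= index < size:
            if self.handle_invalid == "error":
                raise ValueError(f"invalid category index {index} in {column}")
            index = size  # 'keep': route invalid input to one extra index
        if self.handle_invalid == "keep":
            size += 1  # reserve a slot for the extra index
        length = size - 1 if self.drop_last else size
        vec = [0.0] * length
        if index < length:  # a dropped last category stays all zeros
            vec[index] = 1.0
        return vec


# Same index values as the two columns used in the examples of this PR.
columns = {
    "categoryIndex1": [0, 1, 2, 0, 0, 2],
    "categoryIndex2": [1, 0, 1, 2, 1, 0],
}

strict = OneHotSketch(handle_invalid="error").fit(columns)
lenient = OneHotSketch(handle_invalid="keep", drop_last=False).fit(columns)

print(strict.encode("categoryIndex1", 0))   # [1.0, 0.0]
print(strict.encode("categoryIndex1", 2))   # [0.0, 0.0]
print(lenient.encode("categoryIndex1", 7))  # [0.0, 0.0, 0.0, 1.0]
```

Because the category count is learned once in `fit`, every later call to `encode` yields vectors of the same length for a given column, regardless of which indices appear in the data being transformed.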
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,9 +23,8 @@ | |
| import java.util.Arrays; | ||
| import java.util.List; | ||
|
|
||
| import org.apache.spark.ml.feature.OneHotEncoder; | ||
| import org.apache.spark.ml.feature.StringIndexer; | ||
| import org.apache.spark.ml.feature.StringIndexerModel; | ||
| import org.apache.spark.ml.feature.OneHotEncoderEstimator; | ||
| import org.apache.spark.ml.feature.OneHotEncoderModel; | ||
| import org.apache.spark.sql.Dataset; | ||
| import org.apache.spark.sql.Row; | ||
| import org.apache.spark.sql.RowFactory; | ||
|
|
@@ -35,41 +34,37 @@ | |
| import org.apache.spark.sql.types.StructType; | ||
| // $example off$ | ||
|
|
||
| public class JavaOneHotEncoderExample { | ||
| public class JavaOneHotEncoderEstimatorExample { | ||
| public static void main(String[] args) { | ||
| SparkSession spark = SparkSession | ||
| .builder() | ||
| .appName("JavaOneHotEncoderExample") | ||
| .appName("JavaOneHotEncoderEstimatorExample") | ||
| .getOrCreate(); | ||
|
|
||
| // Note: categorical features are usually first encoded with StringIndexer | ||
| // $example on$ | ||
| List<Row> data = Arrays.asList( | ||
| RowFactory.create(0, "a"), | ||
| RowFactory.create(1, "b"), | ||
| RowFactory.create(2, "c"), | ||
| RowFactory.create(3, "a"), | ||
| RowFactory.create(4, "a"), | ||
| RowFactory.create(5, "c") | ||
| RowFactory.create(0.0, 1.0), | ||
| RowFactory.create(1.0, 0.0), | ||
| RowFactory.create(2.0, 1.0), | ||
| RowFactory.create(0.0, 2.0), | ||
| RowFactory.create(0.0, 1.0), | ||
| RowFactory.create(2.0, 0.0) | ||
| ); | ||
|
|
||
| StructType schema = new StructType(new StructField[]{ | ||
| new StructField("id", DataTypes.IntegerType, false, Metadata.empty()), | ||
| new StructField("category", DataTypes.StringType, false, Metadata.empty()) | ||
| new StructField("categoryIndex1", DataTypes.DoubleType, false, Metadata.empty()), | ||
| new StructField("categoryIndex2", DataTypes.DoubleType, false, Metadata.empty()) | ||
|
Contributor
Don't need to pass
Member
Author
Since this is a Java example, the default parameter doesn't seem to work: error: no suitable constructor found for StructField(String,DataType,boolean)
[error] new StructField("categoryIndex1", DataTypes.DoubleType, false),
[error] ^
[error] /root/repos/spark-1/constructor StructField.StructField(String,DataType,boolean,Metadata) is not applicable
[error] (actual and formal argument lists differ in length)
[error] constructor StructField.StructField() is not applicable |
||
| }); | ||
|
|
||
| Dataset<Row> df = spark.createDataFrame(data, schema); | ||
|
|
||
| StringIndexerModel indexer = new StringIndexer() | ||
| .setInputCol("category") | ||
| .setOutputCol("categoryIndex") | ||
| .fit(df); | ||
| Dataset<Row> indexed = indexer.transform(df); | ||
| OneHotEncoderEstimator encoder = new OneHotEncoderEstimator() | ||
| .setInputCols(new String[] {"categoryIndex1", "categoryIndex2"}) | ||
| .setOutputCols(new String[] {"categoryVec1", "categoryVec2"}); | ||
|
|
||
| OneHotEncoder encoder = new OneHotEncoder() | ||
| .setInputCol("categoryIndex") | ||
| .setOutputCol("categoryVec"); | ||
|
|
||
| Dataset<Row> encoded = encoder.transform(indexed); | ||
| OneHotEncoderModel model = encoder.fit(df); | ||
| Dataset<Row> encoded = model.transform(df); | ||
| encoded.show(); | ||
| // $example off$ | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,38 +19,34 @@ | |
| package org.apache.spark.examples.ml | ||
|
|
||
| // $example on$ | ||
| import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} | ||
| import org.apache.spark.ml.feature.OneHotEncoderEstimator | ||
| // $example off$ | ||
| import org.apache.spark.sql.SparkSession | ||
|
|
||
| object OneHotEncoderExample { | ||
| object OneHotEncoderEstimatorExample { | ||
| def main(args: Array[String]): Unit = { | ||
| val spark = SparkSession | ||
| .builder | ||
| .appName("OneHotEncoderExample") | ||
| .appName("OneHotEncoderEstimatorExample") | ||
| .getOrCreate() | ||
|
|
||
| // Note: categorical features are usually first encoded with StringIndexer | ||
| // $example on$ | ||
| val df = spark.createDataFrame(Seq( | ||
|
Contributor
I know the examples are re-creating the existing We could mention in the user guide that it is common to encode categorical features using
Member
Author
OK with me. As an example, it seems a bit lengthy because the two |
||
| (0, "a"), | ||
| (1, "b"), | ||
| (2, "c"), | ||
| (3, "a"), | ||
| (4, "a"), | ||
| (5, "c") | ||
| )).toDF("id", "category") | ||
|
|
||
| val indexer = new StringIndexer() | ||
| .setInputCol("category") | ||
| .setOutputCol("categoryIndex") | ||
| .fit(df) | ||
| val indexed = indexer.transform(df) | ||
|
|
||
| val encoder = new OneHotEncoder() | ||
| .setInputCol("categoryIndex") | ||
| .setOutputCol("categoryVec") | ||
|
|
||
| val encoded = encoder.transform(indexed) | ||
| (0.0, 1.0), | ||
| (1.0, 0.0), | ||
| (2.0, 1.0), | ||
| (0.0, 2.0), | ||
| (0.0, 1.0), | ||
| (2.0, 0.0) | ||
| )).toDF("categoryIndex1", "categoryIndex2") | ||
|
|
||
| val encoder = new OneHotEncoderEstimator() | ||
| .setInputCols(Array("categoryIndex1", "categoryIndex2")) | ||
| .setOutputCols(Array("categoryVec1", "categoryVec2")) | ||
| val model = encoder.fit(df) | ||
|
|
||
| val encoded = model.transform(df) | ||
| encoded.show() | ||
| // $example off$ | ||
|
|
||
|
|
||
I think we should add a little more detail about why it's deprecated.
The reason is that, because the existing `OneHotEncoder` is a stateless transformer, it is not usable on new data where the number of categories may differ from the training data. In order to fix this, a new `OneHotEncoderEstimator` was created that produces a `OneHotEncoderModel` when fit. Add a link to the JIRA ticket for more detail (https://issues.apache.org/jira/browse/SPARK-13030).
Sure. Added.
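The problem with a stateless transformer raised in this thread can be shown with a short pure-Python sketch. It is a simplification under stated assumptions, not the code of either Spark class: `stateless_encode`, `fit_size`, and `fitted_encode` are hypothetical helpers, and the stateless variant is assumed to infer the vector width from whatever data it is handed.

```python
from typing import List, Sequence


def stateless_encode(values: Sequence[int]) -> List[List[float]]:
    # Assumption: with no fitted state, the width comes from the data itself.
    size = max(values) + 1
    return [[1.0 if i == v else 0.0 for i in range(size)] for v in values]


def fit_size(training: Sequence[int]) -> int:
    # "Fitting" here only records the number of categories seen in training.
    return max(training) + 1


def fitted_encode(values: Sequence[int], size: int) -> List[List[float]]:
    # The width is fixed by the fitted size, not by the data being encoded.
    return [[1.0 if i == v else 0.0 for i in range(size)] for v in values]


train = [0, 1, 2]
new_data = [0, 1]  # category 2 happens not to occur in the new data

# Stateless: vector widths disagree between training and new data.
print(len(stateless_encode(train)[0]), len(stateless_encode(new_data)[0]))

# Fitted: the width learned from training is reused for new data.
size = fit_size(train)
print(len(fitted_encode(train, size)[0]), len(fitted_encode(new_data, size)[0]))
```

A downstream model trained on three-element vectors cannot consume two-element ones, which is the mismatch the fitted model avoids by carrying the category count with it.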