diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index ec9792cbbda8f..459994c352da9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCols, HasOutputCols}
+import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.expressions.UserDefinedFunction
@@ -34,7 +34,7 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 
 /** Private trait for params and common methods for OneHotEncoder and OneHotEncoderModel */
 private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
-  with HasInputCols with HasOutputCols {
+  with HasInputCol with HasInputCols with HasOutputCol with HasOutputCols {
 
   /**
    * Param for how to handle invalid data during transform().
@@ -68,12 +68,21 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
   @Since("2.3.0")
   def getDropLast: Boolean = $(dropLast)
 
+  /** Returns the input and output column names, matched pairwise. */
+  private[feature] def getInOutCols(): (Array[String], Array[String]) = {
+    if (isSet(inputCol)) {
+      (Array($(inputCol)), Array($(outputCol)))
+    } else {
+      ($(inputCols), $(outputCols))
+    }
+  }
+
   protected def validateAndTransformSchema(
       schema: StructType,
       dropLast: Boolean,
       keepInvalid: Boolean): StructType = {
-    val inputColNames = $(inputCols)
-    val outputColNames = $(outputCols)
+    ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol), Seq(outputCols))
+    val (inputColNames, outputColNames) = getInOutCols()
 
     require(inputColNames.length == outputColNames.length,
       s"The number of input columns ${inputColNames.length} must be the same as the number of " +
@@ -83,7 +92,7 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
     inputColNames.foreach(SchemaUtils.checkNumericType(schema, _))
 
     // Prepares output columns with proper attributes by examining input columns.
-    val inputFields = $(inputCols).map(schema(_))
+    val inputFields = inputColNames.map(schema(_))
 
     val outputFields = inputFields.zip(outputColNames).map { case (inputField, outputColName) =>
       OneHotEncoderCommon.transformOutputColumnSchema(
@@ -123,6 +132,14 @@ class OneHotEncoder @Since("3.0.0") (@Since("3.0.0") override val uid: String)
   @Since("3.0.0")
   def this() = this(Identifiable.randomUID("oneHotEncoder"))
 
+  /** @group setParam */
+  @Since("3.0.0")
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  @Since("3.0.0")
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
   /** @group setParam */
   @Since("3.0.0")
   def setInputCols(values: Array[String]): this.type = set(inputCols, values)
@@ -150,13 +167,14 @@ class OneHotEncoder @Since("3.0.0") (@Since("3.0.0") override val uid: String)
   override def fit(dataset: Dataset[_]): OneHotEncoderModel = {
     transformSchema(dataset.schema)
 
+    val (inputColumns, outputColumns) = getInOutCols()
     // Compute the plain number of categories without `handleInvalid` and
     // `dropLast` taken into account.
     val transformedSchema = validateAndTransformSchema(dataset.schema, dropLast = false,
       keepInvalid = false)
-    val categorySizes = new Array[Int]($(outputCols).length)
+    val categorySizes = new Array[Int](outputColumns.length)
 
-    val columnToScanIndices = $(outputCols).zipWithIndex.flatMap { case (outputColName, idx) =>
+    val columnToScanIndices = outputColumns.zipWithIndex.flatMap { case (outputColName, idx) =>
       val numOfAttrs = AttributeGroup.fromStructField(
         transformedSchema(outputColName)).size
       if (numOfAttrs < 0) {
@@ -170,8 +188,8 @@ class OneHotEncoder @Since("3.0.0") (@Since("3.0.0") override val uid: String)
     // Some input columns don't have attributes or their attributes don't have necessary info.
     // We need to scan the data to get the number of values for each column.
     if (columnToScanIndices.length > 0) {
-      val inputColNames = columnToScanIndices.map($(inputCols)(_))
-      val outputColNames = columnToScanIndices.map($(outputCols)(_))
+      val inputColNames = columnToScanIndices.map(inputColumns(_))
+      val outputColNames = columnToScanIndices.map(outputColumns(_))
 
       // When fitting data, we want the plain number of categories without `handleInvalid` and
       // `dropLast` taken into account.
@@ -287,7 +305,7 @@ class OneHotEncoderModel private[ml] (
 
   @Since("3.0.0")
   override def transformSchema(schema: StructType): StructType = {
-    val inputColNames = $(inputCols)
+    val (inputColNames, _) = getInOutCols()
 
     require(inputColNames.length == categorySizes.length,
       s"The number of input columns ${inputColNames.length} must be the same as the number of " +
@@ -306,8 +324,9 @@ class OneHotEncoderModel private[ml] (
    */
   private def verifyNumOfValues(schema: StructType): StructType = {
     val configedSizes = getConfigedCategorySizes
-    $(outputCols).zipWithIndex.foreach { case (outputColName, idx) =>
-      val inputColName = $(inputCols)(idx)
+    val (inputColNames, outputColNames) = getInOutCols()
+    outputColNames.zipWithIndex.foreach { case (outputColName, idx) =>
+      val inputColName = inputColNames(idx)
       val attrGroup = AttributeGroup.fromStructField(schema(outputColName))
 
       // If the input metadata specifies number of category for output column,
@@ -327,10 +346,11 @@ class OneHotEncoderModel private[ml] (
   override def transform(dataset: Dataset[_]): DataFrame = {
     val transformedSchema = transformSchema(dataset.schema, logging = true)
     val keepInvalid = $(handleInvalid) == OneHotEncoder.KEEP_INVALID
+    val (inputColNames, outputColNames) = getInOutCols()
 
-    val encodedColumns = $(inputCols).indices.map { idx =>
-      val inputColName = $(inputCols)(idx)
-      val outputColName = $(outputCols)(idx)
+    val encodedColumns = inputColNames.indices.map { idx =>
+      val inputColName = inputColNames(idx)
+      val outputColName = outputColNames(idx)
 
       val outputAttrGroupFromSchema =
         AttributeGroup.fromStructField(transformedSchema(outputColName))
@@ -345,7 +365,7 @@ class OneHotEncoderModel private[ml] (
       encoder(col(inputColName).cast(DoubleType), lit(idx))
         .as(outputColName, metadata)
     }
-    dataset.withColumns($(outputCols), encodedColumns)
+    dataset.withColumns(outputColNames, encodedColumns)
   }
 
   @Since("3.0.0")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index 70f8c029a2575..897251d9815c8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.ml.feature
 
+import org.apache.spark.ml.Pipeline
 import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
 import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.ml.param.ParamsSuite
@@ -62,6 +63,34 @@ class OneHotEncoderSuite extends MLTest with DefaultReadWriteTest {
     }
   }
 
+  test("Single Column: OneHotEncoder dropLast = false") {
+    val data = Seq(
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
+      Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))),
+      Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))),
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
+      Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))))
+
+    val schema = StructType(Array(
+      StructField("input", DoubleType),
+      StructField("expected", new VectorUDT)))
+
+    val df = spark.createDataFrame(sc.parallelize(data), schema)
+
+    val encoder = new OneHotEncoder()
+      .setInputCol("input")
+      .setOutputCol("output")
+    assert(encoder.getDropLast)
+    encoder.setDropLast(false)
+    assert(encoder.getDropLast === false)
+    val model = encoder.fit(df)
+    testTransformer[(Double, Vector)](df, model, "output", "expected") {
+      case Row(output: Vector, expected: Vector) =>
+        assert(output === expected)
+    }
+  }
+
   test("OneHotEncoder dropLast = true") {
     val data = Seq(
       Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))),
@@ -104,6 +133,22 @@ class OneHotEncoderSuite extends MLTest with DefaultReadWriteTest {
     }
   }
 
+  test("Single Column: input column with ML attribute") {
+    val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
+    val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("size")
+      .select(col("size").as("size", attr.toMetadata()))
+    val encoder = new OneHotEncoder()
+      .setInputCol("size")
+      .setOutputCol("encoded")
+    val model = encoder.fit(df)
+    testTransformerByGlobalCheckFunc[(Double)](df, model, "encoded") { rows =>
+      val group = AttributeGroup.fromStructField(rows.head.schema("encoded"))
+      assert(group.size === 2)
+      assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
+      assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
+    }
+  }
+
   test("input column without ML attribute") {
     val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("index")
     val encoder = new OneHotEncoder()
@@ -125,6 +170,13 @@ class OneHotEncoderSuite extends MLTest with DefaultReadWriteTest {
     testDefaultReadWrite(encoder)
   }
 
+  test("Single Column: read/write") {
+    val encoder = new OneHotEncoder()
+      .setInputCol("index")
+      .setOutputCol("encoded")
+    testDefaultReadWrite(encoder)
+  }
+
   test("OneHotEncoderModel read/write") {
     val instance = new OneHotEncoderModel("myOneHotEncoderModel", Array(1, 2, 3))
     val newInstance = testDefaultReadWrite(instance)
@@ -173,6 +225,48 @@ class OneHotEncoderSuite extends MLTest with DefaultReadWriteTest {
     }
   }
 
+  test("Single Column: OneHotEncoder with varying types") {
+    val data = Seq(
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
+      Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))),
+      Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))),
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
+      Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))))
+
+    val schema = StructType(Array(
+      StructField("input", DoubleType),
+      StructField("expected", new VectorUDT)))
+
+    val df = spark.createDataFrame(sc.parallelize(data), schema)
+
+    class NumericTypeWithEncoder[A](val numericType: NumericType)
+      (implicit val encoder: Encoder[(A, Vector)])
+
+    val types = Seq(
+      new NumericTypeWithEncoder[Short](ShortType),
+      new NumericTypeWithEncoder[Long](LongType),
+      new NumericTypeWithEncoder[Int](IntegerType),
+      new NumericTypeWithEncoder[Float](FloatType),
+      new NumericTypeWithEncoder[Byte](ByteType),
+      new NumericTypeWithEncoder[Double](DoubleType),
+      new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder()))
+
+    for (t <- types) {
+      val dfWithTypes = df.select(col("input").cast(t.numericType), col("expected"))
+      val estimator = new OneHotEncoder()
+        .setInputCol("input")
+        .setOutputCol("output")
+        .setDropLast(false)
+
+      val model = estimator.fit(dfWithTypes)
+      testTransformer(dfWithTypes, model, "output", "expected") {
+        case Row(output: Vector, expected: Vector) =>
+          assert(output === expected)
+      }(t.encoder)
+    }
+  }
+
   test("OneHotEncoder: encoding multiple columns and dropLast = false") {
     val data = Seq(
       Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0)))),
@@ -211,6 +305,58 @@ class OneHotEncoderSuite extends MLTest with DefaultReadWriteTest {
     }
   }
 
+  test("Single Column: OneHotEncoder: encoding multiple columns and dropLast = false") {
+    val data = Seq(
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0)))),
+      Row(1.0, Vectors.sparse(3, Seq((1, 1.0))), 3.0, Vectors.sparse(4, Seq((3, 1.0)))),
+      Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))),
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 1.0, Vectors.sparse(4, Seq((1, 1.0)))),
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))),
+      Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0)))))
+
+    val schema = StructType(Array(
+      StructField("input1", DoubleType),
+      StructField("expected1", new VectorUDT),
+      StructField("input2", DoubleType),
+      StructField("expected2", new VectorUDT)))
+
+    val df = spark.createDataFrame(sc.parallelize(data), schema)
+
+    val encoder1 = new OneHotEncoder()
+      .setInputCol("input1")
+      .setOutputCol("output1")
+    assert(encoder1.getDropLast)
+    encoder1.setDropLast(false)
+    assert(encoder1.getDropLast === false)
+
+    val model1 = encoder1.fit(df)
+    testTransformer[(Double, Vector, Double, Vector)](
+      df,
+      model1,
+      "output1",
+      "expected1") {
+      case Row(output1: Vector, expected1: Vector) =>
+        assert(output1 === expected1)
+    }
+
+    val encoder2 = new OneHotEncoder()
+      .setInputCol("input2")
+      .setOutputCol("output2")
+    assert(encoder2.getDropLast)
+    encoder2.setDropLast(false)
+    assert(encoder2.getDropLast === false)
+
+    val model2 = encoder2.fit(df)
+    testTransformer[(Double, Vector, Double, Vector)](
+      df,
+      model2,
+      "output2",
+      "expected2") {
+      case Row(output2: Vector, expected2: Vector) =>
+        assert(output2 === expected2)
+    }
+  }
+
   test("OneHotEncoder: encoding multiple columns and dropLast = true") {
     val data = Seq(
       Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 2.0, Vectors.sparse(3, Seq((2, 1.0)))),
@@ -419,4 +565,52 @@ class OneHotEncoderSuite extends MLTest with DefaultReadWriteTest {
       expectedMessagePart = "OneHotEncoderModel expected 2 categorical values",
       firstResultCol = "encoded")
   }
+
+  test("assert exception is thrown if both multi-column and single-column params are set") {
+    import testImplicits._
+    val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2")
+    ParamsSuite.testExclusiveParams(new OneHotEncoder, df, ("inputCol", "feature1"),
+      ("inputCols", Array("feature1", "feature2")))
+    ParamsSuite.testExclusiveParams(new OneHotEncoder, df, ("inputCol", "feature1"),
+      ("outputCol", "result1"), ("outputCols", Array("result1", "result2")))
+
+    // this should fail because at least one of inputCol and inputCols must be set
+    ParamsSuite.testExclusiveParams(new OneHotEncoder, df, ("outputCol", "feature1"))
+  }
+
+  test("Compare single/multiple column(s) OneHotEncoder in pipeline") {
+    val df = Seq((0.0, 2.0), (1.0, 3.0), (2.0, 0.0), (0.0, 1.0), (0.0, 0.0), (2.0, 2.0))
+      .toDF("input1", "input2")
+
+    val multiColsEncoder = new OneHotEncoder()
+      .setInputCols(Array("input1", "input2"))
+      .setOutputCols(Array("output1", "output2"))
+
+    val plForMultiCols = new Pipeline()
+      .setStages(Array(multiColsEncoder))
+      .fit(df)
+
+    val encoderForCol1 = new OneHotEncoder()
+      .setInputCol("input1")
+      .setOutputCol("output1")
+    val encoderForCol2 = new OneHotEncoder()
+      .setInputCol("input2")
+      .setOutputCol("output2")
+
+    val plForSingleCol = new Pipeline()
+      .setStages(Array(encoderForCol1, encoderForCol2))
+      .fit(df)
+
+    val resultForSingleCol = plForSingleCol.transform(df)
+      .select("output1", "output2")
+      .collect()
+    val resultForMultiCols = plForMultiCols.transform(df)
+      .select("output1", "output2")
+      .collect()
+
+    resultForSingleCol.zip(resultForMultiCols).foreach {
+      case (rowForSingle, rowForMultiCols) =>
+        assert(rowForSingle === rowForMultiCols)
+    }
+  }
 }
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 7645897ea5fc7..7ccdcf8560608 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2354,7 +2354,8 @@ def setOutputCol(self, value):
         return self._set(outputCol=value)
 
 
-class _OneHotEncoderParams(HasInputCols, HasOutputCols, HasHandleInvalid):
+class _OneHotEncoderParams(HasInputCol, HasInputCols, HasOutputCol, HasOutputCols,
+                           HasHandleInvalid):
     """
     Params for :py:class:`OneHotEncoder` and :py:class:`OneHotEncoderModel`.
 
@@ -2415,6 +2416,10 @@ class OneHotEncoder(JavaEstimator, _OneHotEncoderParams, JavaMLReadable, JavaMLW
     'error'
     >>> model.transform(df).head().output
     SparseVector(2, {0: 1.0})
+    >>> single_col_ohe = OneHotEncoder(inputCol="input", outputCol="output")
+    >>> single_col_model = single_col_ohe.fit(df)
+    >>> single_col_model.transform(df).head().output
+    SparseVector(2, {0: 1.0})
     >>> ohePath = temp_path + "/ohe"
     >>> ohe.save(ohePath)
     >>> loadedOHE = OneHotEncoder.load(ohePath)
@@ -2430,9 +2435,11 @@ class OneHotEncoder(JavaEstimator, _OneHotEncoderParams, JavaMLReadable, JavaMLW
     """
 
     @keyword_only
-    def __init__(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True):
+    def __init__(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True,
+                 inputCol=None, outputCol=None):
         """
-        __init__(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True)
+        __init__(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True, \
+                 inputCol=None, outputCol=None)
         """
         super(OneHotEncoder, self).__init__()
         self._java_obj = self._new_java_obj(
@@ -2443,9 +2450,11 @@ def __init__(self, inputCols=None, outputCols=None, handleInvalid="error", dropL
 
     @keyword_only
     @since("2.3.0")
-    def setParams(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True):
+    def setParams(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True,
+                  inputCol=None, outputCol=None):
         """
-        setParams(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True)
+        setParams(self, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True, \
+                  inputCol=None, outputCol=None)
         Sets params for this OneHotEncoder.
""" kwargs = self._input_kwargs @@ -2479,6 +2488,20 @@ def setHandleInvalid(self, value): """ return self._set(handleInvalid=value) + @since("3.0.0") + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + return self._set(inputCol=value) + + @since("3.0.0") + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + return self._set(outputCol=value) + def _create_model(self, java_model): return OneHotEncoderModel(java_model)