Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCols, HasOutputCols}
import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.expressions.UserDefinedFunction
Expand All @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

/** Private trait for params and common methods for OneHotEncoder and OneHotEncoderModel */
private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
with HasInputCols with HasOutputCols {
with HasInputCol with HasInputCols with HasOutputCol with HasOutputCols {

/**
* Param for how to handle invalid data during transform().
Expand Down Expand Up @@ -68,12 +68,21 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
@Since("2.3.0")
def getDropLast: Boolean = $(dropLast)

/**
 * Resolves the configured column names as matched (inputs, outputs) arrays.
 * Single-column params take precedence: when `inputCol` has been set, a
 * one-element pair is returned; otherwise the multi-column params are used.
 */
private[feature] def getInOutCols(): (Array[String], Array[String]) = {
  val singleColumnMode = isSet(inputCol)
  if (!singleColumnMode) {
    ($(inputCols), $(outputCols))
  } else {
    (Array($(inputCol)), Array($(outputCol)))
  }
}

protected def validateAndTransformSchema(
schema: StructType,
dropLast: Boolean,
keepInvalid: Boolean): StructType = {
val inputColNames = $(inputCols)
val outputColNames = $(outputCols)
ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol), Seq(outputCols))
val (inputColNames, outputColNames) = getInOutCols()

require(inputColNames.length == outputColNames.length,
s"The number of input columns ${inputColNames.length} must be the same as the number of " +
Expand All @@ -83,7 +92,7 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
inputColNames.foreach(SchemaUtils.checkNumericType(schema, _))

// Prepares output columns with proper attributes by examining input columns.
val inputFields = $(inputCols).map(schema(_))
val inputFields = inputColNames.map(schema(_))

val outputFields = inputFields.zip(outputColNames).map { case (inputField, outputColName) =>
OneHotEncoderCommon.transformOutputColumnSchema(
Expand Down Expand Up @@ -123,6 +132,14 @@ class OneHotEncoder @Since("3.0.0") (@Since("3.0.0") override val uid: String)
@Since("3.0.0")
def this() = this(Identifiable.randomUID("oneHotEncoder"))

/** @group setParam */
@Since("3.0.0")
def setInputCol(value: String): this.type = set(inputCol, value)

/** @group setParam */
@Since("3.0.0")
def setOutputCol(value: String): this.type = set(outputCol, value)

/** @group setParam */
@Since("3.0.0")
def setInputCols(values: Array[String]): this.type = set(inputCols, values)
Expand Down Expand Up @@ -150,13 +167,14 @@ class OneHotEncoder @Since("3.0.0") (@Since("3.0.0") override val uid: String)
override def fit(dataset: Dataset[_]): OneHotEncoderModel = {
transformSchema(dataset.schema)

val (inputColumns, outputColumns) = getInOutCols()
// Compute the plain number of categories without `handleInvalid` and
// `dropLast` taken into account.
val transformedSchema = validateAndTransformSchema(dataset.schema, dropLast = false,
keepInvalid = false)
val categorySizes = new Array[Int]($(outputCols).length)
val categorySizes = new Array[Int](outputColumns.length)

val columnToScanIndices = $(outputCols).zipWithIndex.flatMap { case (outputColName, idx) =>
val columnToScanIndices = outputColumns.zipWithIndex.flatMap { case (outputColName, idx) =>
val numOfAttrs = AttributeGroup.fromStructField(
transformedSchema(outputColName)).size
if (numOfAttrs < 0) {
Expand All @@ -170,8 +188,8 @@ class OneHotEncoder @Since("3.0.0") (@Since("3.0.0") override val uid: String)
// Some input columns don't have attributes or their attributes don't have necessary info.
// We need to scan the data to get the number of values for each column.
if (columnToScanIndices.length > 0) {
val inputColNames = columnToScanIndices.map($(inputCols)(_))
val outputColNames = columnToScanIndices.map($(outputCols)(_))
val inputColNames = columnToScanIndices.map(inputColumns(_))
val outputColNames = columnToScanIndices.map(outputColumns(_))

// When fitting data, we want the plain number of categories without `handleInvalid` and
// `dropLast` taken into account.
Expand Down Expand Up @@ -287,7 +305,7 @@ class OneHotEncoderModel private[ml] (

@Since("3.0.0")
override def transformSchema(schema: StructType): StructType = {
val inputColNames = $(inputCols)
val (inputColNames, _) = getInOutCols()

require(inputColNames.length == categorySizes.length,
s"The number of input columns ${inputColNames.length} must be the same as the number of " +
Expand All @@ -306,8 +324,9 @@ class OneHotEncoderModel private[ml] (
*/
private def verifyNumOfValues(schema: StructType): StructType = {
val configedSizes = getConfigedCategorySizes
$(outputCols).zipWithIndex.foreach { case (outputColName, idx) =>
val inputColName = $(inputCols)(idx)
val (inputColNames, outputColNames) = getInOutCols()
outputColNames.zipWithIndex.foreach { case (outputColName, idx) =>
val inputColName = inputColNames(idx)
val attrGroup = AttributeGroup.fromStructField(schema(outputColName))

// If the input metadata specifies number of category for output column,
Expand All @@ -327,10 +346,11 @@ class OneHotEncoderModel private[ml] (
override def transform(dataset: Dataset[_]): DataFrame = {
val transformedSchema = transformSchema(dataset.schema, logging = true)
val keepInvalid = $(handleInvalid) == OneHotEncoder.KEEP_INVALID
val (inputColNames, outputColNames) = getInOutCols()

val encodedColumns = $(inputCols).indices.map { idx =>
val inputColName = $(inputCols)(idx)
val outputColName = $(outputCols)(idx)
val encodedColumns = inputColNames.indices.map { idx =>
val inputColName = inputColNames(idx)
val outputColName = outputColNames(idx)

val outputAttrGroupFromSchema =
AttributeGroup.fromStructField(transformedSchema(outputColName))
Expand All @@ -345,7 +365,7 @@ class OneHotEncoderModel private[ml] (
encoder(col(inputColName).cast(DoubleType), lit(idx))
.as(outputColName, metadata)
}
dataset.withColumns($(outputCols), encodedColumns)
dataset.withColumns(outputColNames, encodedColumns)
}

@Since("3.0.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.ml.feature

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param.ParamsSuite
Expand Down Expand Up @@ -62,6 +63,34 @@ class OneHotEncoderSuite extends MLTest with DefaultReadWriteTest {
}
}

test("Single Column: OneHotEncoder dropLast = false") {
val data = Seq(
Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))),
Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))),
Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))))

val schema = StructType(Array(
StructField("input", DoubleType),
StructField("expected", new VectorUDT)))

val df = spark.createDataFrame(sc.parallelize(data), schema)

val encoder = new OneHotEncoder()
.setInputCol("input")
.setOutputCol("output")
assert(encoder.getDropLast)
encoder.setDropLast(false)
assert(encoder.getDropLast === false)
val model = encoder.fit(df)
testTransformer[(Double, Vector)](df, model, "output", "expected") {
case Row(output: Vector, expected: Vector) =>
assert(output === expected)
}
}

test("OneHotEncoder dropLast = true") {
val data = Seq(
Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))),
Expand Down Expand Up @@ -104,6 +133,22 @@ class OneHotEncoderSuite extends MLTest with DefaultReadWriteTest {
}
}

test("Single Column: input column with ML attribute") {
val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("size")
.select(col("size").as("size", attr.toMetadata()))
val encoder = new OneHotEncoder()
.setInputCol("size")
.setOutputCol("encoded")
val model = encoder.fit(df)
testTransformerByGlobalCheckFunc[(Double)](df, model, "encoded") { rows =>
val group = AttributeGroup.fromStructField(rows.head.schema("encoded"))
assert(group.size === 2)
assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0))
assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1))
}
}

test("input column without ML attribute") {
val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("index")
val encoder = new OneHotEncoder()
Expand All @@ -125,6 +170,13 @@ class OneHotEncoderSuite extends MLTest with DefaultReadWriteTest {
testDefaultReadWrite(encoder)
}

test("Single Column: read/write") {
val encoder = new OneHotEncoder()
.setInputCol("index")
.setOutputCol("encoded")
testDefaultReadWrite(encoder)
}

test("OneHotEncoderModel read/write") {
val instance = new OneHotEncoderModel("myOneHotEncoderModel", Array(1, 2, 3))
val newInstance = testDefaultReadWrite(instance)
Expand Down Expand Up @@ -173,6 +225,48 @@ class OneHotEncoderSuite extends MLTest with DefaultReadWriteTest {
}
}

test("Single Column: OneHotEncoder with varying types") {
val data = Seq(
Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))),
Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))),
Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))))

val schema = StructType(Array(
StructField("input", DoubleType),
StructField("expected", new VectorUDT)))

val df = spark.createDataFrame(sc.parallelize(data), schema)

class NumericTypeWithEncoder[A](val numericType: NumericType)
(implicit val encoder: Encoder[(A, Vector)])

val types = Seq(
new NumericTypeWithEncoder[Short](ShortType),
new NumericTypeWithEncoder[Long](LongType),
new NumericTypeWithEncoder[Int](IntegerType),
new NumericTypeWithEncoder[Float](FloatType),
new NumericTypeWithEncoder[Byte](ByteType),
new NumericTypeWithEncoder[Double](DoubleType),
new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder()))

for (t <- types) {
val dfWithTypes = df.select(col("input").cast(t.numericType), col("expected"))
val estimator = new OneHotEncoder()
.setInputCol("input")
.setOutputCol("output")
.setDropLast(false)

val model = estimator.fit(dfWithTypes)
testTransformer(dfWithTypes, model, "output", "expected") {
case Row(output: Vector, expected: Vector) =>
assert(output === expected)
}(t.encoder)
}
}

test("OneHotEncoder: encoding multiple columns and dropLast = false") {
val data = Seq(
Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0)))),
Expand Down Expand Up @@ -211,6 +305,58 @@ class OneHotEncoderSuite extends MLTest with DefaultReadWriteTest {
}
}

test("Single Column: OneHotEncoder: encoding multiple columns and dropLast = false") {
val data = Seq(
Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0)))),
Row(1.0, Vectors.sparse(3, Seq((1, 1.0))), 3.0, Vectors.sparse(4, Seq((3, 1.0)))),
Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))),
Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 1.0, Vectors.sparse(4, Seq((1, 1.0)))),
Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))),
Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0)))))

val schema = StructType(Array(
StructField("input1", DoubleType),
StructField("expected1", new VectorUDT),
StructField("input2", DoubleType),
StructField("expected2", new VectorUDT)))

val df = spark.createDataFrame(sc.parallelize(data), schema)

val encoder1 = new OneHotEncoder()
.setInputCol("input1")
.setOutputCol("output1")
assert(encoder1.getDropLast)
encoder1.setDropLast(false)
assert(encoder1.getDropLast === false)

val model1 = encoder1.fit(df)
testTransformer[(Double, Vector, Double, Vector)](
df,
model1,
"output1",
"expected1") {
case Row(output1: Vector, expected1: Vector) =>
assert(output1 === expected1)
}

val encoder2 = new OneHotEncoder()
.setInputCol("input2")
.setOutputCol("output2")
assert(encoder2.getDropLast)
encoder2.setDropLast(false)
assert(encoder2.getDropLast === false)

val model2 = encoder2.fit(df)
testTransformer[(Double, Vector, Double, Vector)](
df,
model2,
"output2",
"expected2") {
case Row(output2: Vector, expected2: Vector) =>
assert(output2 === expected2)
}
}

test("OneHotEncoder: encoding multiple columns and dropLast = true") {
val data = Seq(
Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 2.0, Vectors.sparse(3, Seq((2, 1.0)))),
Expand Down Expand Up @@ -419,4 +565,52 @@ class OneHotEncoderSuite extends MLTest with DefaultReadWriteTest {
expectedMessagePart = "OneHotEncoderModel expected 2 categorical values",
firstResultCol = "encoded")
}

test("assert exception is thrown if both multi-column and single-column params are set") {
import testImplicits._
val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2")
ParamsSuite.testExclusiveParams(new OneHotEncoder, df, ("inputCol", "feature1"),
("inputCols", Array("feature1", "feature2")))
ParamsSuite.testExclusiveParams(new OneHotEncoder, df, ("inputCol", "feature1"),
("outputCol", "result1"), ("outputCols", Array("result1", "result2")))

// this should fail because at least one of inputCol and inputCols must be set
ParamsSuite.testExclusiveParams(new OneHotEncoder, df, ("outputCol", "feature1"))
}

test("Compare single/multiple column(s) OneHotEncoder in pipeline") {
val df = Seq((0.0, 2.0), (1.0, 3.0), (2.0, 0.0), (0.0, 1.0), (0.0, 0.0), (2.0, 2.0))
.toDF("input1", "input2")

val multiColsEncoder = new OneHotEncoder()
.setInputCols(Array("input1", "input2"))
.setOutputCols(Array("output1", "output2"))

val plForMultiCols = new Pipeline()
.setStages(Array(multiColsEncoder))
.fit(df)

val encoderForCol1 = new OneHotEncoder()
.setInputCol("input1")
.setOutputCol("output1")
val encoderForCol2 = new OneHotEncoder()
.setInputCol("input2")
.setOutputCol("output2")

val plForSingleCol = new Pipeline()
.setStages(Array(encoderForCol1, encoderForCol2))
.fit(df)

val resultForSingleCol = plForSingleCol.transform(df)
.select("output1", "output2")
.collect()
val resultForMultiCols = plForMultiCols.transform(df)
.select("output1", "output2")
.collect()

resultForSingleCol.zip(resultForMultiCols).foreach {
case (rowForSingle, rowForMultiCols) =>
assert(rowForSingle === rowForMultiCols)
}
}
}
Loading