-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-23578][ML] Add multicolumn support for Binarizer #20732
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,18 +24,20 @@ import org.apache.spark.ml.Transformer | |
| import org.apache.spark.ml.attribute.BinaryAttribute | ||
| import org.apache.spark.ml.linalg._ | ||
| import org.apache.spark.ml.param._ | ||
| import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} | ||
| import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols} | ||
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.sql._ | ||
| import org.apache.spark.sql.expressions.UserDefinedFunction | ||
| import org.apache.spark.sql.functions._ | ||
| import org.apache.spark.sql.types._ | ||
|
|
||
| /** | ||
| * Binarize a column of continuous features given a threshold. | ||
| */ | ||
| @Since("1.4.0") | ||
| final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) | ||
| extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { | ||
| final class Binarizer @Since("1.4.0")(@Since("1.4.0") override val uid: String) | ||
| extends Transformer with HasInputCol with HasOutputCol with HasInputCols with HasOutputCols | ||
| with DefaultParamsWritable { | ||
|
|
||
| @Since("1.4.0") | ||
| def this() = this(Identifiable.randomUID("binarizer")) | ||
|
|
@@ -45,66 +47,117 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) | |
| * The features greater than the threshold, will be binarized to 1.0. | ||
| * The features equal to or less than the threshold, will be binarized to 0.0. | ||
| * Default: 0.0 | ||
| * | ||
| * @group param | ||
| */ | ||
| @Since("1.4.0") | ||
| val threshold: DoubleParam = | ||
| new DoubleParam(this, "threshold", "threshold used to binarize continuous features") | ||
| new DoubleParam(this, "threshold", "threshold used to binarize continuous features") | ||
|
|
||
| /** @group param */ | ||
| @Since("2.3.1") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| val thresholds: DoubleArrayParam = | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what about extending |
||
| new DoubleArrayParam(this, "thresholds", "thresholds used to binarize continuous features") | ||
|
|
||
| /** @group getParam */ | ||
| @Since("1.4.0") | ||
| def getThreshold: Double = $(threshold) | ||
|
|
||
| /** @group getParam */ | ||
| @Since("2.3.1") | ||
| def getThresholds: Array[Double] = $(thresholds) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("1.4.0") | ||
| def setThreshold(value: Double): this.type = set(threshold, value) | ||
|
|
||
| setDefault(threshold -> 0.0) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.3.1") | ||
| def setThresholds(value: Array[Double]): this.type = set(thresholds, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("1.4.0") | ||
| def setInputCol(value: String): this.type = set(inputCol, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.3.1") | ||
| def setInputCols(value: Array[String]): this.type = set(inputCols, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("1.4.0") | ||
| def setOutputCol(value: String): this.type = set(outputCol, value) | ||
|
|
||
| @Since("2.3.1") | ||
| def setOutputCols(value: Array[String]): this.type = set(outputCols, value) | ||
|
|
||
| @Since("2.3.1") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't need to add a `@Since` annotation for a private method.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please refer to method I guess we do not need this method. |
||
| private[feature] def isBinarizerMultipleColumns(): Boolean = { | ||
| if (isSet(inputCols) && isSet(inputCol)) { | ||
| logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " + | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We cannot set both of them, according to the current ML convention. |
||
| "`Binarizer` only maps one column specified by `inputCol`") | ||
| false | ||
| } else if (isSet(inputCols)) { | ||
| true | ||
| } else { | ||
| false | ||
| } | ||
| } | ||
|
|
||
| @Since("2.0.0") | ||
| override def transform(dataset: Dataset[_]): DataFrame = { | ||
| val outputSchema = transformSchema(dataset.schema, logging = true) | ||
| val schema = dataset.schema | ||
| val inputType = schema($(inputCol)).dataType | ||
| val td = $(threshold) | ||
|
|
||
| val binarizerDouble = udf { in: Double => if (in > td) 1.0 else 0.0 } | ||
| val binarizerVector = udf { (data: Vector) => | ||
| val indices = ArrayBuilder.make[Int] | ||
| val values = ArrayBuilder.make[Double] | ||
|
|
||
| data.foreachActive { (index, value) => | ||
| if (value > td) { | ||
| indices += index | ||
| values += 1.0 | ||
|
|
||
| val (inputColName, outputColName, td) = if (isBinarizerMultipleColumns()) { | ||
| ($(inputCols).toSeq, $(outputCols).toSeq, $(thresholds).toSeq) | ||
| } | ||
| else { | ||
| (Seq($(inputCol)), Seq($(outputCol)), Seq($(threshold))) | ||
| } | ||
|
|
||
| val inputType = inputColName.map { col => schema(col).dataType } | ||
|
|
||
| val binarizerDouble: Seq[UserDefinedFunction] = td.map { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I prefer not to create the val outputCols = inputColNames.zip(outputColNames).zip(thresholds).map{ case (inputColName, outputColName, threshold) =>
schema(inputColName).dataType match{
...
}
} |
||
| td => udf { (in: Double) => if (in > td) 1.0 else 0.0 } | ||
| } | ||
|
|
||
| val binarizerVector = td.map { td => | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This processing UDF has been updated for vectors with threshold < 0. Please refer to the latest master branch. |
||
| udf { (data: Vector) => | ||
| val indices = ArrayBuilder.make[Int] | ||
| val values = ArrayBuilder.make[Double] | ||
|
|
||
| data.foreachActive { (index, value) => | ||
| if (value > td) { | ||
| indices += index | ||
| values += 1.0 | ||
| } | ||
| } | ||
| } | ||
|
|
||
| Vectors.sparse(data.size, indices.result(), values.result()).compressed | ||
| Vectors.sparse(data.size, indices.result(), values.result()).compressed | ||
| } | ||
| } | ||
|
|
||
| val metadata = outputSchema($(outputCol)).metadata | ||
| val metadata = outputColName.map { col => | ||
| outputSchema(col).metadata | ||
| } | ||
|
|
||
| inputType match { | ||
| case DoubleType => | ||
| dataset.select(col("*"), binarizerDouble(col($(inputCol))).as($(outputCol), metadata)) | ||
| case _: VectorUDT => | ||
| dataset.select(col("*"), binarizerVector(col($(inputCol))).as($(outputCol), metadata)) | ||
| val newCols = inputType.zip(inputColName).zip(td).zipWithIndex.map { | ||
| case (((inputType, inputColName), td), idx) => | ||
| inputType match { | ||
| case DoubleType => binarizerDouble(idx)(col(inputColName)) | ||
| case _ => binarizerVector(idx)(col(inputColName)) | ||
| } | ||
| } | ||
| dataset.withColumns(outputColName, newCols, metadata) | ||
| } | ||
|
|
||
| @Since("1.4.0") | ||
| override def transformSchema(schema: StructType): StructType = { | ||
| val inputType = schema($(inputCol)).dataType | ||
| val outputColName = $(outputCol) | ||
| @Since("2.3.1") | ||
| def validateSchema(schema: StructType, | ||
| inputColName: String, | ||
| outputColName: String): StructField = { | ||
| val inputType = schema(inputColName).dataType | ||
|
|
||
| val outCol: StructField = inputType match { | ||
| case DoubleType => | ||
|
|
@@ -118,7 +171,22 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) | |
| if (schema.fieldNames.contains(outputColName)) { | ||
| throw new IllegalArgumentException(s"Output column $outputColName already exists.") | ||
| } | ||
| StructType(schema.fields :+ outCol) | ||
| outCol | ||
| } | ||
|
|
||
| @Since("1.4.0") | ||
| override def transformSchema(schema: StructType): StructType = { | ||
| val (inputColName, outputColName) = if (isBinarizerMultipleColumns()) { | ||
| ($(inputCols), $(outputCols)) | ||
| } | ||
| else { | ||
| (Array($(inputCol)), Array($(outputCol))) | ||
| } | ||
|
|
||
| val outputField = for (i <- 0 until inputColName.length) yield { | ||
| validateSchema(schema, inputColName(i), outputColName(i)) | ||
| } | ||
| StructType(schema.fields ++ outputField.filter(_ != null)) | ||
| } | ||
|
|
||
| @Since("1.4.1") | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,10 +27,20 @@ class BinarizerSuite extends MLTest with DefaultReadWriteTest { | |
| import testImplicits._ | ||
|
|
||
| @transient var data: Array[Double] = _ | ||
| @transient var data2: Array[Double] = _ | ||
| @transient var expectedBinarizer1: Array[Double] = _ | ||
| @transient var expectedBinarizer2: Array[Double] = _ | ||
| @transient var expectedBinarizer3: Array[Double] = _ | ||
| @transient var expectedBinarizer4: Array[Double] = _ | ||
|
|
||
| override def beforeAll(): Unit = { | ||
| super.beforeAll() | ||
| data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4) | ||
| data2 = Array(-0.1, 0.5, -0.2, 0.3, -0.8, -0.7, 0.1, 0.4) | ||
| expectedBinarizer1 = Array(1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0) | ||
| expectedBinarizer2 = Array(0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0) | ||
| expectedBinarizer3 = Array(0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0) | ||
| expectedBinarizer4 = Array(0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) | ||
| } | ||
|
|
||
| test("params") { | ||
|
|
@@ -51,6 +61,23 @@ class BinarizerSuite extends MLTest with DefaultReadWriteTest { | |
| } | ||
| } | ||
|
|
||
| test("multiple columns: Binarize continuous features with default parameter") { | ||
| val dataFrame: DataFrame = (0 until data.length).map { idx => | ||
| (data(idx), data2(idx), expectedBinarizer1(idx), expectedBinarizer2(idx)) | ||
| }.toDF("feature1", "feature2", "expected1", "expected2") | ||
|
|
||
| val binarizer: Binarizer = new Binarizer() | ||
| .setInputCols(Array("feature1", "feature2")) | ||
| .setOutputCols(Array("result1", "result2")) | ||
| .setThresholds(Array(0.0, 0.0)) | ||
|
|
||
| binarizer.transform(dataFrame).select("result1", "expected1", "result2", "expected2").collect(). | ||
| foreach { | ||
| case Row(r1: Double, e1: Double, r2: Double, e2: Double ) => assert(r1 == e1 && r2 == e2, | ||
| "The feature value is not correct after binarization.") | ||
| } | ||
| } | ||
|
|
||
| test("Binarize continuous features with setter") { | ||
| val threshold: Double = 0.2 | ||
| val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) | ||
|
|
@@ -67,6 +94,23 @@ class BinarizerSuite extends MLTest with DefaultReadWriteTest { | |
| } | ||
| } | ||
|
|
||
| test("multiple columns:Binarize continuous features with setter") { | ||
| val dataFrame: DataFrame = (0 until data.length).map { idx => | ||
| (data(idx), data2(idx), expectedBinarizer3(idx), expectedBinarizer4(idx)) | ||
| }.toDF("feature1", "feature2", "expected1", "expected2") | ||
|
|
||
| val binarizer: Binarizer = new Binarizer() | ||
| .setInputCols(Array("feature1", "feature2")) | ||
| .setOutputCols(Array("result1", "result2")) | ||
| .setThresholds(Array(0.2, 0.2)) | ||
|
|
||
| binarizer.transform(dataFrame).select("result1", "expected1", "result2", "expected2").collect(). | ||
| foreach { | ||
| case Row(r1: Double, e1: Double, r2: Double, e2: Double ) => assert(r1 == e1 && r2 == e2, | ||
| "The feature value is not correct after binarization.") | ||
| } | ||
| } | ||
|
|
||
| test("Binarize vector of continuous features with default parameter") { | ||
| val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0) | ||
| val dataFrame: DataFrame = Seq( | ||
|
|
@@ -83,6 +127,24 @@ class BinarizerSuite extends MLTest with DefaultReadWriteTest { | |
| } | ||
| } | ||
|
|
||
| test("multiple column: Binarize vector of continuous features with default parameter") { | ||
| val dataFrame: DataFrame = Seq( | ||
| (Vectors.dense(data), Vectors.dense(data2), | ||
| Vectors.dense(expectedBinarizer1), Vectors.dense(expectedBinarizer2)) | ||
| ).toDF("feature1", "feature2", "expected1", "expected2") | ||
|
|
||
| val binarizer: Binarizer = new Binarizer() | ||
| .setInputCols(Array("feature1", "feature2")) | ||
| .setOutputCols(Array("result1", "result2")) | ||
| .setThresholds(Array(0.0, 0.0)) | ||
|
|
||
| binarizer.transform(dataFrame).select("result1", "expected1", "result2", "expected2").collect(). | ||
| foreach { | ||
| case Row(r1: Vector, e1: Vector, r2: Vector, e2: Vector ) => assert(r1 == e1 && r2 == e2, | ||
| "The feature value is not correct after binarization.") | ||
| } | ||
| } | ||
|
|
||
| test("Binarize vector of continuous features with setter") { | ||
| val threshold: Double = 0.2 | ||
| val defaultBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) | ||
|
|
@@ -101,6 +163,23 @@ class BinarizerSuite extends MLTest with DefaultReadWriteTest { | |
| } | ||
| } | ||
|
|
||
| test("multiple column: Binarize vector of continuous features with setter") { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think one suite with both double and vector columns is enough. |
||
| val dataFrame: DataFrame = Seq( | ||
| (Vectors.dense(data), Vectors.dense(data2), | ||
| Vectors.dense(expectedBinarizer3), Vectors.dense(expectedBinarizer4)) | ||
| ).toDF("feature1", "feature2", "expected1", "expected2") | ||
|
|
||
| val binarizer: Binarizer = new Binarizer() | ||
| .setInputCols(Array("feature1", "feature2")) | ||
| .setOutputCols(Array("result1", "result2")) | ||
| .setThresholds(Array(0.2, 0.2)) | ||
|
|
||
| binarizer.transform(dataFrame).select("result1", "expected1", "result2", "expected2").collect(). | ||
| foreach { | ||
| case Row(r1: Vector, e1: Vector, r2: Vector, e2: Vector ) => assert(r1 == e1 && r2 == e2, | ||
| "The feature value is not correct after binarization.") | ||
| } | ||
| } | ||
|
|
||
| test("read/write") { | ||
| val t = new Binarizer() | ||
|
|
@@ -110,3 +189,5 @@ class BinarizerSuite extends MLTest with DefaultReadWriteTest { | |
| testDefaultReadWrite(t) | ||
| } | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what about adding a vector column?