diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala index c2852aacb05d..9c56a69871ee 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala @@ -45,7 +45,23 @@ object BinarizerExample { binarizedDataFrame.show() // $example off$ + // $example on$ + val data2 = Array((0, 0.1), (1, 0.8), (2, 0.2)) + val dataFrame2 = spark.createDataFrame((0 until data.length).map { idx => + (data(idx), data2(idx)) + }).toDF("feature1", "feature2") + + val binarizer2: Binarizer = new Binarizer() + .setInputCols(Array("feature1", "feature2")) + .setOutputCols(Array("result1", "result2")) + .setThresholds(Array(0.5, 0.5)) + + val binarizedDataFrame2 = binarizer2.transform(dataFrame2) + binarizedDataFrame2.show() + // $example off$ + spark.stop() } } + // scalastyle:on println diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 2b0862c60fdf..084d2bc15181 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -24,9 +24,10 @@ import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.BinaryAttribute import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols} import org.apache.spark.ml.util._ import org.apache.spark.sql._ +import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -34,8 +35,9 @@ import org.apache.spark.sql.types._ * Binarize a column of continuous features given a threshold. */ @Since("1.4.0") -final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) - extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { +final class Binarizer @Since("1.4.0")(@Since("1.4.0") override val uid: String) + extends Transformer with HasInputCol with HasOutputCol with HasInputCols with HasOutputCols + with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("binarizer")) @@ -45,66 +47,117 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) * The features greater than the threshold, will be binarized to 1.0. * The features equal to or less than the threshold, will be binarized to 0.0. * Default: 0.0 + * * @group param */ @Since("1.4.0") val threshold: DoubleParam = - new DoubleParam(this, "threshold", "threshold used to binarize continuous features") + new DoubleParam(this, "threshold", "threshold used to binarize continuous features") + + /** @group param */ + @Since("2.3.1") + val thresholds: DoubleArrayParam = + new DoubleArrayParam(this, "thresholds", "thresholds used to binarize continuous features") /** @group getParam */ @Since("1.4.0") def getThreshold: Double = $(threshold) + /** @group getParam */ + @Since("2.3.1") + def getThresholds: Array[Double] = $(thresholds) + /** @group setParam */ @Since("1.4.0") def setThreshold(value: Double): this.type = set(threshold, value) setDefault(threshold -> 0.0) + /** @group setParam */ + @Since("2.3.1") + def setThresholds(value: Array[Double]): this.type = set(thresholds, value) + /** @group setParam */ @Since("1.4.0") def setInputCol(value: String): this.type = set(inputCol, value) + /** @group setParam */ + @Since("2.3.1") + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + /** @group setParam */ @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) + @Since("2.3.1") + def setOutputCols(value: Array[String]): this.type = set(outputCols, value) + + @Since("2.3.1") + private[feature] def isBinarizerMultipleColumns(): Boolean = { + if (isSet(inputCols) && isSet(inputCol)) { + logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " + + "`Binarizer` only maps one column specified by `inputCol`") + false + } else if (isSet(inputCols)) { + true + } else { + false + } + } + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { val outputSchema = transformSchema(dataset.schema, logging = true) val schema = dataset.schema - val inputType = schema($(inputCol)).dataType - val td = $(threshold) - - val binarizerDouble = udf { in: Double => if (in > td) 1.0 else 0.0 } - val binarizerVector = udf { (data: Vector) => - val indices = ArrayBuilder.make[Int] - val values = ArrayBuilder.make[Double] - - data.foreachActive { (index, value) => - if (value > td) { - indices += index - values += 1.0 + + val (inputColName, outputColName, td) = if (isBinarizerMultipleColumns()) { + ($(inputCols).toSeq, $(outputCols).toSeq, $(thresholds).toSeq) + } + else { + (Seq($(inputCol)), Seq($(outputCol)), Seq($(threshold))) + } + + val inputType = inputColName.map { col => schema(col).dataType } + + val binarizerDouble: Seq[UserDefinedFunction] = td.map { + td => udf { (in: Double) => if (in > td) 1.0 else 0.0 } + } + + val binarizerVector = td.map { td => + udf { (data: Vector) => + val indices = ArrayBuilder.make[Int] + val values = ArrayBuilder.make[Double] + + data.foreachActive { (index, value) => + if (value > td) { + indices += index + values += 1.0 + } } - } - Vectors.sparse(data.size, indices.result(), values.result()).compressed + Vectors.sparse(data.size, indices.result(), values.result()).compressed + } } - val metadata = outputSchema($(outputCol)).metadata + val metadata = outputColName.map { col => + outputSchema(col).metadata + } - inputType match { - case DoubleType => - dataset.select(col("*"), binarizerDouble(col($(inputCol))).as($(outputCol), metadata)) - case _: VectorUDT => - dataset.select(col("*"), binarizerVector(col($(inputCol))).as($(outputCol), metadata)) + val newCols = inputType.zip(inputColName).zip(td).zipWithIndex.map { + case (((inputType, inputColName), td), idx) => + inputType match { + case DoubleType => binarizerDouble(idx)(col(inputColName)) + case _ => binarizerVector(idx)(col(inputColName)) + } } + dataset.withColumns(outputColName, newCols, metadata) } - @Since("1.4.0") - override def transformSchema(schema: StructType): StructType = { - val inputType = schema($(inputCol)).dataType - val outputColName = $(outputCol) + @Since("2.3.1") + def validateSchema(schema: StructType, + inputColName: String, + outputColName: String): StructField = { + val inputType = schema(inputColName).dataType val outCol: StructField = inputType match { case DoubleType => @@ -118,7 +171,22 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) if (schema.fieldNames.contains(outputColName)) { throw new IllegalArgumentException(s"Output column $outputColName already exists.") } - StructType(schema.fields :+ outCol) + outCol + } + + @Since("1.4.0") + override def transformSchema(schema: StructType): StructType = { + val (inputColName, outputColName) = if (isBinarizerMultipleColumns()) { + ($(inputCols), $(outputCols)) + } + else { + (Array($(inputCol)), Array($(outputCol))) + } + + val outputField = for (i <- 0 until inputColName.length) yield { + validateSchema(schema, inputColName(i), outputColName(i)) + } + StructType(schema.fields ++ outputField.filter(_ != null)) } @Since("1.4.1") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 05d4a6ee2dab..1c000cbfc5a5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -27,10 +27,20 @@ class BinarizerSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @transient var data: Array[Double] = _ + @transient var data2: Array[Double] = _ + @transient var expectedBinarizer1: Array[Double] = _ + @transient var expectedBinarizer2: Array[Double] = _ + @transient var expectedBinarizer3: Array[Double] = _ + @transient var expectedBinarizer4: Array[Double] = _ override def beforeAll(): Unit = { super.beforeAll() data = Array(0.1, -0.5, 0.2, -0.3, 0.8, 0.7, -0.1, -0.4) + data2 = Array(-0.1, 0.5, -0.2, 0.3, -0.8, -0.7, 0.1, 0.4) + expectedBinarizer1 = Array(1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0) + expectedBinarizer2 = Array(0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0) + expectedBinarizer3 = Array(0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0) + expectedBinarizer4 = Array(0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) } test("params") { @@ -51,6 +61,23 @@ class BinarizerSuite extends MLTest with DefaultReadWriteTest { } } + test("multiple columns: Binarize continuous features with default parameter") { + val dataFrame: DataFrame = (0 until data.length).map { idx => + (data(idx), data2(idx), expectedBinarizer1(idx), expectedBinarizer2(idx)) + }.toDF("feature1", "feature2", "expected1", "expected2") + + val binarizer: Binarizer = new Binarizer() + .setInputCols(Array("feature1", "feature2")) + .setOutputCols(Array("result1", "result2")) + .setThresholds(Array(0.0, 0.0)) + + binarizer.transform(dataFrame).select("result1", "expected1", "result2", "expected2").collect(). + foreach { + case Row(r1: Double, e1: Double, r2: Double, e2: Double ) => assert(r1 == e1 && r2 == e2, + "The feature value is not correct after binarization.") + } + } + test("Binarize continuous features with setter") { val threshold: Double = 0.2 val thresholdBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) @@ -67,6 +94,23 @@ class BinarizerSuite extends MLTest with DefaultReadWriteTest { } } + test("multiple columns:Binarize continuous features with setter") { + val dataFrame: DataFrame = (0 until data.length).map { idx => + (data(idx), data2(idx), expectedBinarizer3(idx), expectedBinarizer4(idx)) + }.toDF("feature1", "feature2", "expected1", "expected2") + + val binarizer: Binarizer = new Binarizer() + .setInputCols(Array("feature1", "feature2")) + .setOutputCols(Array("result1", "result2")) + .setThresholds(Array(0.2, 0.2)) + + binarizer.transform(dataFrame).select("result1", "expected1", "result2", "expected2").collect(). + foreach { + case Row(r1: Double, e1: Double, r2: Double, e2: Double ) => assert(r1 == e1 && r2 == e2, + "The feature value is not correct after binarization.") + } + } + test("Binarize vector of continuous features with default parameter") { val defaultBinarized: Array[Double] = data.map(x => if (x > 0.0) 1.0 else 0.0) val dataFrame: DataFrame = Seq( @@ -83,6 +127,24 @@ class BinarizerSuite extends MLTest with DefaultReadWriteTest { } } + test("multiple column: Binarize vector of continuous features with default parameter") { + val dataFrame: DataFrame = Seq( + (Vectors.dense(data), Vectors.dense(data2), + Vectors.dense(expectedBinarizer1), Vectors.dense(expectedBinarizer2)) + ).toDF("feature1", "feature2", "expected1", "expected2") + + val binarizer: Binarizer = new Binarizer() + .setInputCols(Array("feature1", "feature2")) + .setOutputCols(Array("result1", "result2")) + .setThresholds(Array(0.0, 0.0)) + + binarizer.transform(dataFrame).select("result1", "expected1", "result2", "expected2").collect(). + foreach { + case Row(r1: Vector, e1: Vector, r2: Vector, e2: Vector ) => assert(r1 == e1 && r2 == e2, + "The feature value is not correct after binarization.") + } + } + test("Binarize vector of continuous features with setter") { val threshold: Double = 0.2 val defaultBinarized: Array[Double] = data.map(x => if (x > threshold) 1.0 else 0.0) @@ -101,6 +163,23 @@ class BinarizerSuite extends MLTest with DefaultReadWriteTest { } } + test("multiple column: Binarize vector of continuous features with setter") { + val dataFrame: DataFrame = Seq( + (Vectors.dense(data), Vectors.dense(data2), + Vectors.dense(expectedBinarizer3), Vectors.dense(expectedBinarizer4)) + ).toDF("feature1", "feature2", "expected1", "expected2") + + val binarizer: Binarizer = new Binarizer() + .setInputCols(Array("feature1", "feature2")) + .setOutputCols(Array("result1", "result2")) + .setThresholds(Array(0.2, 0.2)) + + binarizer.transform(dataFrame).select("result1", "expected1", "result2", "expected2").collect(). + foreach { + case Row(r1: Vector, e1: Vector, r2: Vector, e2: Vector ) => assert(r1 == e1 && r2 == e2, + "The feature value is not correct after binarization.") + } + } test("read/write") { val t = new Binarizer() @@ -110,3 +189,5 @@ class BinarizerSuite extends MLTest with DefaultReadWriteTest { testDefaultReadWrite(t) } } + +