diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
index 99c0a0df5367..5e593f731c62 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -23,7 +23,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasInputCols, HasOutputCols}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
@@ -32,7 +32,8 @@ import org.apache.spark.sql.types._
 /**
  * Params for [[Imputer]] and [[ImputerModel]].
  */
-private[feature] trait ImputerParams extends Params with HasInputCols with HasOutputCols {
+private[feature] trait ImputerParams extends Params with HasInputCol with HasInputCols
+  with HasOutputCol with HasOutputCols {
 
   /**
    * The imputation strategy. Currently only "mean" and "median" are supported.
@@ -63,15 +64,26 @@ private[feature] trait ImputerParams extends Params with HasInputCols with HasOu
   /** @group getParam */
   def getMissingValue: Double = $(missingValue)
 
+  /** Returns the paired input and output column names. */
+  private[feature] def getInOutCols(): (Array[String], Array[String]) = {
+    if (isSet(inputCol)) {
+      (Array($(inputCol)), Array($(outputCol)))
+    } else {
+      ($(inputCols), $(outputCols))
+    }
+  }
+
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    require($(inputCols).length == $(inputCols).distinct.length, s"inputCols contains" +
-      s" duplicates: (${$(inputCols).mkString(", ")})")
-    require($(outputCols).length == $(outputCols).distinct.length, s"outputCols contains" +
-      s" duplicates: (${$(outputCols).mkString(", ")})")
-    require($(inputCols).length == $(outputCols).length, s"inputCols(${$(inputCols).length})" +
-      s" and outputCols(${$(outputCols).length}) should have the same length")
-    val outputFields = $(inputCols).zip($(outputCols)).map { case (inputCol, outputCol) =>
+    ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol), Seq(outputCols))
+    val (inputColNames, outputColNames) = getInOutCols()
+    require(inputColNames.length == inputColNames.distinct.length, s"inputCols contains" +
+      s" duplicates: (${inputColNames.mkString(", ")})")
+    require(outputColNames.length == outputColNames.distinct.length, s"outputCols contains" +
+      s" duplicates: (${outputColNames.mkString(", ")})")
+    require(inputColNames.length == outputColNames.length, s"inputCols(${inputColNames.length})" +
+      s" and outputCols(${outputColNames.length}) should have the same length")
+    val outputFields = inputColNames.zip(outputColNames).map { case (inputCol, outputCol) =>
       val inputField = schema(inputCol)
       SchemaUtils.checkNumericType(schema, inputCol)
       StructField(outputCol, inputField.dataType, inputField.nullable)
@@ -101,6 +113,14 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
   @Since("2.2.0")
   def this() = this(Identifiable.randomUID("imputer"))
 
+  /** @group setParam */
+  @Since("3.0.0")
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  @Since("3.0.0")
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
   /** @group setParam */
setParam */ @Since("2.2.0") def setInputCols(value: Array[String]): this.type = set(inputCols, value) @@ -126,7 +146,9 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) transformSchema(dataset.schema, logging = true) val spark = dataset.sparkSession - val cols = $(inputCols).map { inputCol => + val (inputColumns, _) = getInOutCols() + + val cols = inputColumns.map { inputCol => when(col(inputCol).equalTo($(missingValue)), null) .when(col(inputCol).isNaN, null) .otherwise(col(inputCol)) @@ -139,7 +161,7 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) // Function avg will ignore null automatically. // For a column only containing null, avg will return null. val row = dataset.select(cols.map(avg): _*).head() - Array.range(0, $(inputCols).length).map { i => + Array.range(0, inputColumns.length).map { i => if (row.isNullAt(i)) { Double.NaN } else { @@ -150,7 +172,7 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) case Imputer.median => // Function approxQuantile will ignore null automatically. // For a column only containing null, approxQuantile will return an empty array. - dataset.select(cols: _*).stat.approxQuantile($(inputCols), Array(0.5), 0.001) + dataset.select(cols: _*).stat.approxQuantile(inputColumns, Array(0.5), 0.001) .map { array => if (array.isEmpty) { Double.NaN @@ -160,7 +182,7 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) } } - val emptyCols = $(inputCols).zip(results).filter(_._2.isNaN).map(_._1) + val emptyCols = inputColumns.zip(results).filter(_._2.isNaN).map(_._1) if (emptyCols.nonEmpty) { throw new SparkException(s"surrogate cannot be computed. " + s"All the values in ${emptyCols.mkString(",")} are Null, Nan or " + @@ -168,7 +190,7 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) } val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(results))) - val schema = StructType($(inputCols).map(col => StructField(col, DoubleType, nullable = false))) + val schema = StructType(inputColumns.map(col => StructField(col, DoubleType, nullable = false))) val surrogateDF = spark.createDataFrame(rows, schema) copyValues(new ImputerModel(uid, surrogateDF).setParent(this)) } @@ -205,6 +227,14 @@ class ImputerModel private[ml] ( import ImputerModel._ + /** @group setParam */ + @Since("3.0.0") + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + @Since("3.0.0") + def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ def setInputCols(value: Array[String]): this.type = set(inputCols, value) @@ -213,9 +243,11 @@ class ImputerModel private[ml] ( override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq + val (inputColumns, outputColumns) = getInOutCols + val surrogates = surrogateDF.select(inputColumns.map(col): _*).head().toSeq + - val newCols = $(inputCols).zip($(outputCols)).zip(surrogates).map { + val newCols = inputColumns.zip(outputColumns).zip(surrogates).map { case ((inputCol, outputCol), surrogate) => val inputType = dataset.schema(inputCol).dataType val ic = col(inputCol).cast(DoubleType) @@ -224,7 +256,7 @@ class ImputerModel private[ml] ( .otherwise(ic) .cast(inputType) } - dataset.withColumns($(outputCols), newCols).toDF() + dataset.withColumns(outputColumns, newCols).toDF() } override def transformSchema(schema: StructType): StructType = { 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
index 02ef261a6c06..dfee2b4029c8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
@@ -17,6 +17,8 @@
 package org.apache.spark.ml.feature
 
 import org.apache.spark.SparkException
+import org.apache.spark.ml.Pipeline
+import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.{DataFrame, Row}
@@ -36,7 +38,31 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
     val imputer = new Imputer()
       .setInputCols(Array("value1", "value2"))
       .setOutputCols(Array("out1", "out2"))
-    ImputerSuite.iterateStrategyTest(imputer, df)
+    ImputerSuite.iterateStrategyTest(true, imputer, df)
+  }
+
+  test("Single Column: Imputer for Double with default missing Value NaN") {
+    val df1 = spark.createDataFrame( Seq(
+      (0, 1.0, 1.0, 1.0),
+      (1, 11.0, 11.0, 11.0),
+      (2, 3.0, 3.0, 3.0),
+      (3, Double.NaN, 5.0, 3.0)
+    )).toDF("id", "value", "expected_mean_value", "expected_median_value")
+    val imputer1 = new Imputer()
+      .setInputCol("value")
+      .setOutputCol("out")
+    ImputerSuite.iterateStrategyTest(false, imputer1, df1)
+
+    val df2 = spark.createDataFrame( Seq(
+      (0, 4.0, 4.0, 4.0),
+      (1, 12.0, 12.0, 12.0),
+      (2, Double.NaN, 10.0, 12.0),
+      (3, 14.0, 14.0, 14.0)
+    )).toDF("id", "value", "expected_mean_value", "expected_median_value")
+    val imputer2 = new Imputer()
+      .setInputCol("value")
+      .setOutputCol("out")
+    ImputerSuite.iterateStrategyTest(false, imputer2, df2)
   }
 
   test("Imputer should handle NaNs when computing surrogate value, if missingValue is not NaN") {
@@ -48,7 +74,20 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
     )).toDF("id", "value", "expected_mean_value", "expected_median_value")
     val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
       .setMissingValue(-1.0)
-    ImputerSuite.iterateStrategyTest(imputer, df)
+    ImputerSuite.iterateStrategyTest(true, imputer, df)
+  }
+
+  test("Single Column: Imputer should handle NaNs when computing surrogate value," +
+    " if missingValue is not NaN") {
+    val df = spark.createDataFrame( Seq(
+      (0, 1.0, 1.0, 1.0),
+      (1, 3.0, 3.0, 3.0),
+      (2, Double.NaN, Double.NaN, Double.NaN),
+      (3, -1.0, 2.0, 1.0)
+    )).toDF("id", "value", "expected_mean_value", "expected_median_value")
+    val imputer = new Imputer().setInputCol("value").setOutputCol("out")
+      .setMissingValue(-1.0)
+    ImputerSuite.iterateStrategyTest(false, imputer, df)
   }
 
   test("Imputer for Float with missing Value -1.0") {
@@ -61,7 +100,20 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
     )).toDF("id", "value", "expected_mean_value", "expected_median_value")
     val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
       .setMissingValue(-1)
-    ImputerSuite.iterateStrategyTest(imputer, df)
+    ImputerSuite.iterateStrategyTest(true, imputer, df)
+  }
+
+  test("Single Column: Imputer for Float with missing Value -1.0") {
+    val df = spark.createDataFrame( Seq(
+      (0, 1.0F, 1.0F, 1.0F),
+      (1, 3.0F, 3.0F, 3.0F),
+      (2, 10.0F, 10.0F, 10.0F),
+      (3, 10.0F, 10.0F, 10.0F),
+      (4, -1.0F, 6.0F, 3.0F)
+    )).toDF("id", "value", "expected_mean_value", "expected_median_value")
+    val imputer = new Imputer().setInputCol("value").setOutputCol("out")
+      .setMissingValue(-1)
+    ImputerSuite.iterateStrategyTest(false, imputer, df)
   }
 
   test("Imputer should impute null as well as 'missingValue'") {
@@ -74,7 +126,20 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
     )).toDF("id", "rawValue", "expected_mean_value", "expected_median_value")
     val df = rawDf.selectExpr("*", "IF(rawValue=-1.0, null, rawValue) as value")
     val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
-    ImputerSuite.iterateStrategyTest(imputer, df)
+    ImputerSuite.iterateStrategyTest(true, imputer, df)
+  }
+
+  test("Single Column: Imputer should impute null as well as 'missingValue'") {
+    val rawDf = spark.createDataFrame( Seq(
+      (0, 4.0, 4.0, 4.0),
+      (1, 10.0, 10.0, 10.0),
+      (2, 10.0, 10.0, 10.0),
+      (3, Double.NaN, 8.0, 10.0),
+      (4, -1.0, 8.0, 10.0)
+    )).toDF("id", "rawValue", "expected_mean_value", "expected_median_value")
+    val df = rawDf.selectExpr("*", "IF(rawValue=-1.0, null, rawValue) as value")
+    val imputer = new Imputer().setInputCol("value").setOutputCol("out")
+    ImputerSuite.iterateStrategyTest(false, imputer, df)
   }
 
   test("Imputer should work with Structured Streaming") {
@@ -99,6 +164,28 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
     }
   }
 
+  test("Single Column: Imputer should work with Structured Streaming") {
+    val localSpark = spark
+    import localSpark.implicits._
+    val df = Seq[(java.lang.Double, Double)](
+      (4.0, 4.0),
+      (10.0, 10.0),
+      (10.0, 10.0),
+      (Double.NaN, 8.0),
+      (null, 8.0)
+    ).toDF("value", "expected_mean_value")
+    val imputer = new Imputer()
+      .setInputCol("value")
+      .setOutputCol("out")
+      .setStrategy("mean")
+    val model = imputer.fit(df)
+    testTransformer[(java.lang.Double, Double)](df, model, "expected_mean_value", "out") {
+      case Row(exp: java.lang.Double, out: Double) =>
+        assert((exp.isNaN && out.isNaN) || (exp == out),
+          s"Imputed values differ. Expected: $exp, actual: $out")
+    }
+  }
+
   test("Imputer throws exception when surrogate cannot be computed") {
     val df = spark.createDataFrame( Seq(
       (0, Double.NaN, 1.0, 1.0),
@@ -117,6 +204,24 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
     }
   }
 
+  test("Single Column: Imputer throws exception when surrogate cannot be computed") {
+    val df = spark.createDataFrame( Seq(
+      (0, Double.NaN, 1.0, 1.0),
+      (1, Double.NaN, 3.0, 3.0),
+      (2, Double.NaN, Double.NaN, Double.NaN)
+    )).toDF("id", "value", "expected_mean_value", "expected_median_value")
+    Seq("mean", "median").foreach { strategy =>
+      val imputer = new Imputer().setInputCol("value").setOutputCol("out")
+        .setStrategy(strategy)
+      withClue("Imputer should fail when all the values are invalid") {
+        val e: SparkException = intercept[SparkException] {
+          val model = imputer.fit(df)
+        }
+        assert(e.getMessage.contains("surrogate cannot be computed"))
+      }
+    }
+  }
+
   test("Imputer input & output column validation") {
     val df = spark.createDataFrame( Seq(
       (0, 1.0, 1.0, 1.0),
@@ -164,6 +269,14 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
     testDefaultReadWrite(t)
   }
 
+  test("Single Column: Imputer read/write") {
+    val t = new Imputer()
+      .setInputCol("myInputCol")
+      .setOutputCol("myOutputCol")
+      .setMissingValue(-1.0)
+    testDefaultReadWrite(t)
+  }
+
   test("ImputerModel read/write") {
     val spark = this.spark
     import spark.implicits._
@@ -178,6 +291,20 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
     assert(newInstance.surrogateDF.collect() === instance.surrogateDF.collect())
   }
 
+  test("Single Column: ImputerModel read/write") {
+    val spark = this.spark
+    import spark.implicits._
+    val surrogateDF = Seq(1.234).toDF("myInputCol")
+
+    val instance = new ImputerModel(
+      "myImputer", surrogateDF)
+      .setInputCol("myInputCol")
+      .setOutputCol("myOutputCol")
+    val newInstance = testDefaultReadWrite(instance)
+    assert(newInstance.surrogateDF.columns === instance.surrogateDF.columns)
+    assert(newInstance.surrogateDF.collect() === instance.surrogateDF.collect())
+  }
+
   test("Imputer for IntegerType with default missing value null") {
 
     val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)](
@@ -195,7 +322,27 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
     for (mType <- types) {
       // cast all columns to desired data type for testing
       val df2 = df.select(df.columns.map(c => col(c).cast(mType)): _*)
-      ImputerSuite.iterateStrategyTest(imputer, df2)
+      ImputerSuite.iterateStrategyTest(true, imputer, df2)
+    }
+  }
+
+  test("Single Column: Imputer for IntegerType with default missing value null") {
+    val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)](
+      (1, 1, 1),
+      (11, 11, 11),
+      (3, 3, 3),
+      (null, 5, 3)
+    )).toDF("value", "expected_mean_value", "expected_median_value")
+
+    val imputer = new Imputer()
+      .setInputCol("value")
+      .setOutputCol("out")
+
+    val types = Seq(IntegerType, LongType)
+    for (mType <- types) {
+      // cast all columns to desired data type for testing
+      val df2 = df.select(df.columns.map(c => col(c).cast(mType)): _*)
+      ImputerSuite.iterateStrategyTest(false, imputer, df2)
     }
   }
 
@@ -217,7 +364,85 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
     for (mType <- types) {
       // cast all columns to desired data type for testing
       val df2 = df.select(df.columns.map(c => col(c).cast(mType)): _*)
-      ImputerSuite.iterateStrategyTest(imputer, df2)
+      ImputerSuite.iterateStrategyTest(true, imputer, df2)
+    }
+  }
+
+  test("Single Column: Imputer for IntegerType with missing value -1") {
+    val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)](
+      (1, 1, 1),
+      (11, 11, 11),
+      (3, 3, 3),
+      (-1, 5, 3)
+    )).toDF("value", "expected_mean_value", "expected_median_value")
+
+    val imputer = new Imputer()
+      .setInputCol("value")
+      .setOutputCol("out")
+      .setMissingValue(-1.0)
+
+    val types = Seq(IntegerType, LongType)
+    for (mType <- types) {
+      // cast all columns to desired data type for testing
+      val df2 = df.select(df.columns.map(c => col(c).cast(mType)): _*)
+      ImputerSuite.iterateStrategyTest(false, imputer, df2)
+    }
+  }
+
+  test("assert exception is thrown if both multi-column and single-column params are set") {
+    import testImplicits._
+    val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2")
+    ParamsSuite.testExclusiveParams(new Imputer, df, ("inputCol", "feature1"),
+      ("inputCols", Array("feature1", "feature2")))
+    ParamsSuite.testExclusiveParams(new Imputer, df, ("inputCol", "feature1"),
+      ("outputCol", "result1"), ("outputCols", Array("result1", "result2")))
+
+    // this should fail because at least one of inputCol and inputCols must be set
+    ParamsSuite.testExclusiveParams(new Imputer, df, ("outputCol", "feature1"))
+  }
+
+  test("Compare single/multiple column(s) Imputer in pipeline") {
+    val df = spark.createDataFrame( Seq(
+      (0, 1.0, 4.0),
+      (1, 11.0, 12.0),
+      (2, 3.0, Double.NaN),
+      (3, Double.NaN, 14.0)
+    )).toDF("id", "value1", "value2")
+    Seq("mean", "median").foreach { strategy =>
+      val multiColsImputer = new Imputer()
+        .setInputCols(Array("value1", "value2"))
+        .setOutputCols(Array("result1", "result2"))
+        .setStrategy(strategy)
+
+      val plForMultiCols = new Pipeline()
+        .setStages(Array(multiColsImputer))
+        .fit(df)
+
+      val imputerForCol1 = new Imputer()
+        .setInputCol("value1")
+        .setOutputCol("result1")
+        .setStrategy(strategy)
+      val imputerForCol2 = new Imputer()
+        .setInputCol("value2")
+        .setOutputCol("result2")
+        .setStrategy(strategy)
+
+      val plForSingleCol = new Pipeline()
+        .setStages(Array(imputerForCol1, imputerForCol2))
+        .fit(df)
+
+      val resultForSingleCol = plForSingleCol.transform(df)
+        .select("result1", "result2")
+        .collect()
+      val resultForMultiCols = plForMultiCols.transform(df)
+        .select("result1", "result2")
+        .collect()
+
+      resultForSingleCol.zip(resultForMultiCols).foreach {
+        case (rowForSingle, rowForMultiCols) =>
+          assert(rowForSingle.getDouble(0) == rowForMultiCols.getDouble(0) &&
+            rowForSingle.getDouble(1) == rowForMultiCols.getDouble(1))
+      }
     }
   }
 }
@@ -228,34 +453,45 @@ object ImputerSuite {
    * Imputation strategy. Available options are ["mean", "median"].
    * @param df DataFrame with columns "id", "value", "expected_mean", "expected_median"
    */
-  def iterateStrategyTest(imputer: Imputer, df: DataFrame): Unit = {
+  def iterateStrategyTest(isMultiCol: Boolean, imputer: Imputer, df: DataFrame): Unit = {
     Seq("mean", "median").foreach { strategy =>
       imputer.setStrategy(strategy)
       val model = imputer.fit(df)
       val resultDF = model.transform(df)
-      imputer.getInputCols.zip(imputer.getOutputCols).foreach { case (inputCol, outputCol) =>
-
-        // check dataType is consistent between input and output
-        val inputType = resultDF.schema(inputCol).dataType
-        val outputType = resultDF.schema(outputCol).dataType
-        assert(inputType == outputType, "Output type is not the same as input type.")
-
-        // check value
-        resultDF.select(s"expected_${strategy}_$inputCol", outputCol).collect().foreach {
-          case Row(exp: Float, out: Float) =>
-            assert((exp.isNaN && out.isNaN) || (exp == out),
-              s"Imputed values differ. Expected: $exp, actual: $out")
-          case Row(exp: Double, out: Double) =>
-            assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5),
-              s"Imputed values differ. Expected: $exp, actual: $out")
-          case Row(exp: Integer, out: Integer) =>
-            assert(exp == out,
-              s"Imputed values differ. Expected: $exp, actual: $out")
-          case Row(exp: Long, out: Long) =>
-            assert(exp == out,
-              s"Imputed values differ. Expected: $exp, actual: $out")
+      if (isMultiCol) {
+        imputer.getInputCols.zip(imputer.getOutputCols).foreach { case (inputCol, outputCol) =>
+          verifyTransformResult(strategy, inputCol, outputCol, resultDF)
         }
+      } else {
+        verifyTransformResult(strategy, imputer.getInputCol, imputer.getOutputCol, resultDF)
       }
     }
   }
+
+  def verifyTransformResult(
+      strategy: String,
+      inputCol: String,
+      outputCol: String,
+      resultDF: DataFrame): Unit = {
+    // check dataType is consistent between input and output
+    val inputType = resultDF.schema(inputCol).dataType
+    val outputType = resultDF.schema(outputCol).dataType
+    assert(inputType == outputType, "Output type is not the same as input type.")
+
+    // check value
+    resultDF.select(s"expected_${strategy}_$inputCol", outputCol).collect().foreach {
+      case Row(exp: Float, out: Float) =>
+        assert((exp.isNaN && out.isNaN) || (exp == out),
+          s"Imputed values differ. Expected: $exp, actual: $out")
+      case Row(exp: Double, out: Double) =>
+        assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5),
+          s"Imputed values differ. Expected: $exp, actual: $out")
+      case Row(exp: Integer, out: Integer) =>
+        assert(exp == out,
+          s"Imputed values differ. Expected: $exp, actual: $out")
+      case Row(exp: Long, out: Long) =>
+        assert(exp == out,
+          s"Imputed values differ. Expected: $exp, actual: $out")
+    }
+  }
 }
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 11bb7941b5d9..7645897ea5fc 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -1464,7 +1464,7 @@ def numDocs(self):
         return self._call_java("numDocs")
 
 
-class _ImputerParams(HasInputCols, HasOutputCols):
+class _ImputerParams(HasInputCol, HasInputCols, HasOutputCol, HasOutputCols):
     """
     Params for :py:class:`Imputer` and :py:class:`ImputerModel`.
 
@@ -1540,6 +1540,55 @@ class Imputer(JavaEstimator, _ImputerParams, JavaMLReadable, JavaMLWritable):
     +---+---+-----+-----+
     |1.0|NaN|  4.0|  NaN|
     ...
+    >>> df1 = spark.createDataFrame([(1.0,), (2.0,), (float("nan"),), (4.0,), (5.0,)], ["a"])
+    >>> imputer1 = Imputer(inputCol="a", outputCol="out_a")
+    >>> model1 = imputer1.fit(df1)
+    >>> model1.surrogateDF.show()
+    +---+
+    |  a|
+    +---+
+    |3.0|
+    +---+
+    ...
+    >>> model1.transform(df1).show()
+    +---+-----+
+    |  a|out_a|
+    +---+-----+
+    |1.0|  1.0|
+    |2.0|  2.0|
+    |NaN|  3.0|
+    ...
+    >>> imputer1.setStrategy("median").setMissingValue(1.0).fit(df1).transform(df1).show()
+    +---+-----+
+    |  a|out_a|
+    +---+-----+
+    |1.0|  4.0|
+    ...
+    >>> df2 = spark.createDataFrame([(float("nan"),), (float("nan"),), (3.0,), (4.0,), (5.0,)],
+    ...     ["b"])
+    >>> imputer2 = Imputer(inputCol="b", outputCol="out_b")
+    >>> model2 = imputer2.fit(df2)
+    >>> model2.surrogateDF.show()
+    +---+
+    |  b|
+    +---+
+    |4.0|
+    +---+
+    ...
+    >>> model2.transform(df2).show()
+    +---+-----+
+    |  b|out_b|
+    +---+-----+
+    |NaN|  4.0|
+    |NaN|  4.0|
+    |3.0|  3.0|
+    ...
+    >>> imputer2.setStrategy("median").setMissingValue(1.0).fit(df2).transform(df2).show()
+    +---+-----+
+    |  b|out_b|
+    +---+-----+
+    |NaN|  NaN|
+    ...
     >>> imputerPath = temp_path + "/imputer"
     >>> imputer.save(imputerPath)
     >>> loadedImputer = Imputer.load(imputerPath)
@@ -1558,10 +1607,10 @@ class Imputer(JavaEstimator, _ImputerParams, JavaMLReadable, JavaMLWritable):
 
     @keyword_only
     def __init__(self, strategy="mean", missingValue=float("nan"), inputCols=None,
-                 outputCols=None):
+                 outputCols=None, inputCol=None, outputCol=None):
         """
         __init__(self, strategy="mean", missingValue=float("nan"), inputCols=None, \
-                 outputCols=None):
+                 outputCols=None, inputCol=None, outputCol=None):
         """
         super(Imputer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Imputer", self.uid)
@@ -1572,10 +1621,10 @@ def __init__(self, strategy="mean", missingValue=float("nan"), inputCols=None,
     @keyword_only
     @since("2.2.0")
     def setParams(self, strategy="mean", missingValue=float("nan"), inputCols=None,
-                  outputCols=None):
+                  outputCols=None, inputCol=None, outputCol=None):
         """
         setParams(self, strategy="mean", missingValue=float("nan"), inputCols=None, \
-                  outputCols=None)
+                  outputCols=None, inputCol=None, outputCol=None)
         Sets params for this Imputer.
         """
         kwargs = self._input_kwargs
@@ -1609,6 +1658,20 @@ def setOutputCols(self, value):
         """
         return self._set(outputCols=value)
 
+    @since("3.0.0")
+    def setInputCol(self, value):
+        """
+        Sets the value of :py:attr:`inputCol`.
+        """
+        return self._set(inputCol=value)
+
+    @since("3.0.0")
+    def setOutputCol(self, value):
+        """
+        Sets the value of :py:attr:`outputCol`.
+        """
+        return self._set(outputCol=value)
+
    def _create_model(self, java_model):
        return ImputerModel(java_model)
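To round off, a hedged Scala sketch of the mutual-exclusion rule that both the Scala and Python changes above rely on. The object name, session setup, and column names are illustrative; the failure path assumes checkSingleVsMultiColumnParams rejects mixed params with an IllegalArgumentException at fit() time, as the validation code in validateAndTransformSchema suggests:

```scala
import org.apache.spark.ml.feature.Imputer
import org.apache.spark.sql.SparkSession

object ImputerExclusivitySketch {
  def main(args: Array[String]): Unit = {
    // Illustrative local session; any SparkSession works.
    val spark = SparkSession.builder().master("local[2]").appName("imputer-exclusive").getOrCreate()
    import spark.implicits._

    val df = Seq(1.0, 2.0, Double.NaN).toDF("a")

    // Mixing the single- and multi-column params on one instance should be
    // rejected when fit() validates the schema, via
    // ParamValidators.checkSingleVsMultiColumnParams.
    val bad = new Imputer()
      .setInputCol("a").setOutputCol("out_a")
      .setInputCols(Array("a")).setOutputCols(Array("out_a"))
    try {
      bad.fit(df)
    } catch {
      case e: IllegalArgumentException =>
        println(s"rejected as expected: ${e.getMessage}")
    }
    spark.stop()
  }
}
```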