66 changes: 49 additions & 17 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -23,7 +23,7 @@ import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCols, HasOutputCols}
import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
@@ -32,7 +32,8 @@ import org.apache.spark.sql.types._
/**
* Params for [[Imputer]] and [[ImputerModel]].
*/
private[feature] trait ImputerParams extends Params with HasInputCols with HasOutputCols {
private[feature] trait ImputerParams extends Params with HasInputCol with HasInputCols
with HasOutputCol with HasOutputCols {

/**
* The imputation strategy. Currently only "mean" and "median" are supported.
@@ -63,15 +64,26 @@ private[feature] trait ImputerParams extends Params with HasInputCols with HasOu
/** @group getParam */
def getMissingValue: Double = $(missingValue)

/** Returns the input and output column names corresponding in pair. */
private[feature] def getInOutCols(): (Array[String], Array[String]) = {
if (isSet(inputCol)) {
(Array($(inputCol)), Array($(outputCol)))
} else {
($(inputCols), $(outputCols))
}
}

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
require($(inputCols).length == $(inputCols).distinct.length, s"inputCols contains" +
s" duplicates: (${$(inputCols).mkString(", ")})")
require($(outputCols).length == $(outputCols).distinct.length, s"outputCols contains" +
s" duplicates: (${$(outputCols).mkString(", ")})")
require($(inputCols).length == $(outputCols).length, s"inputCols(${$(inputCols).length})" +
s" and outputCols(${$(outputCols).length}) should have the same length")
val outputFields = $(inputCols).zip($(outputCols)).map { case (inputCol, outputCol) =>
ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol), Seq(outputCols))
val (inputColNames, outputColNames) = getInOutCols()
require(inputColNames.length == inputColNames.distinct.length, s"inputCols contains" +
s" duplicates: (${inputColNames.mkString(", ")})")
require(outputColNames.length == outputColNames.distinct.length, s"outputCols contains" +
s" duplicates: (${outputColNames.mkString(", ")})")
require(inputColNames.length == outputColNames.length, s"inputCols(${inputColNames.length})" +
s" and outputCols(${outputColNames.length}) should have the same length")
val outputFields = inputColNames.zip(outputColNames).map { case (inputCol, outputCol) =>
val inputField = schema(inputCol)
SchemaUtils.checkNumericType(schema, inputCol)
StructField(outputCol, inputField.dataType, inputField.nullable)
@@ -101,6 +113,14 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
@Since("2.2.0")
def this() = this(Identifiable.randomUID("imputer"))

/** @group setParam */
@Since("3.0.0")
def setInputCol(value: String): this.type = set(inputCol, value)

/** @group setParam */
@Since("3.0.0")
def setOutputCol(value: String): this.type = set(outputCol, value)

/** @group setParam */
@Since("2.2.0")
def setInputCols(value: Array[String]): this.type = set(inputCols, value)
@@ -126,7 +146,9 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
transformSchema(dataset.schema, logging = true)
val spark = dataset.sparkSession

val cols = $(inputCols).map { inputCol =>
val (inputColumns, _) = getInOutCols()

val cols = inputColumns.map { inputCol =>
when(col(inputCol).equalTo($(missingValue)), null)
.when(col(inputCol).isNaN, null)
.otherwise(col(inputCol))
@@ -139,7 +161,7 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
// Function avg will ignore null automatically.
// For a column only containing null, avg will return null.
val row = dataset.select(cols.map(avg): _*).head()
Array.range(0, $(inputCols).length).map { i =>
Array.range(0, inputColumns.length).map { i =>
if (row.isNullAt(i)) {
Double.NaN
} else {
@@ -150,7 +172,7 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
case Imputer.median =>
// Function approxQuantile will ignore null automatically.
// For a column only containing null, approxQuantile will return an empty array.
dataset.select(cols: _*).stat.approxQuantile($(inputCols), Array(0.5), 0.001)
dataset.select(cols: _*).stat.approxQuantile(inputColumns, Array(0.5), 0.001)
.map { array =>
if (array.isEmpty) {
Double.NaN
@@ -160,15 +182,15 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
}
}

val emptyCols = $(inputCols).zip(results).filter(_._2.isNaN).map(_._1)
val emptyCols = inputColumns.zip(results).filter(_._2.isNaN).map(_._1)
if (emptyCols.nonEmpty) {
throw new SparkException(s"surrogate cannot be computed. " +
s"All the values in ${emptyCols.mkString(",")} are Null, Nan or " +
s"missingValue(${$(missingValue)})")
}

val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(results)))
val schema = StructType($(inputCols).map(col => StructField(col, DoubleType, nullable = false)))
val schema = StructType(inputColumns.map(col => StructField(col, DoubleType, nullable = false)))
val surrogateDF = spark.createDataFrame(rows, schema)
copyValues(new ImputerModel(uid, surrogateDF).setParent(this))
}
@@ -205,6 +227,14 @@ class ImputerModel private[ml] (

import ImputerModel._

/** @group setParam */
@Since("3.0.0")
def setInputCol(value: String): this.type = set(inputCol, value)
zero323 (Member) commented on Jan 13, 2020:

What is the intended purpose of this method?

As it is implemented right now, it doesn't seem to have any practical applications:

  • If the model has been created with a single column, surrogate will contain only a single column, so there is nothing to set here.

  • If the model has been created with multiple columns, setInputCol / setOutputCol should clear inputCols and outputCols, otherwise validation will fail. I guess something like this:

    @Since("3.0.0")
    def setInputCol(value: String): this.type = {
      clear(inputCols)
      clear(outputCols)
      set(inputCol, value)
    }
    
    @Since("3.0.0")
    def setOutputCol(value: String): this.type = {
      clear(inputCols)
      clear(outputCols)
      set(outputCol, value)
    }
    

I am asking, because these two are missing in Python (#27195).

The PR author (Contributor) replied:

@zero323 I actually realized I missed the two setters in Python when I checked the parity between Python and Scala last night. I fixed them along with a few other problems.

The PR author (Contributor) replied:

@zero323 There is a check on the Scala side to make sure only setInputCol/setOutputCol or setInputCols/setOutputCols is set.

zero323 (Member) replied on Jan 14, 2020:

> @zero323 There is a check on the Scala side to make sure only setInputCol/setOutputCol or setInputCols/setOutputCols is set.

That is what confuses me. Let's say the workflow looks like this:

import org.apache.spark.ml.feature.Imputer

val df = Seq((1, 2)).toDF("x1", "x2")

val mm = new Imputer()
   .setInputCols(Array("x1", "x2"))
   .setOutputCols(Array("x1_", "x2_"))
   .fit(df)

You cannot switch to single col at the model level:

mm.setInputCol("x1").setOutputCol("x1_").transform(df)

// java.lang.IllegalArgumentException: requirement failed: ImputerModel ImputerModel: uid=imputer_5923f59d0d3a, strategy=mean, missingValue=NaN, numInputCols=2, numOutputCols=2 requires exactly one of inputCol, inputCols Params to be set, but both are set.

without clearing cols explicitly:

mm.clear(mm.inputCols).clear(mm.outputCols).transform(df)

That's really not an intuitive workflow, if this is what was intended.

If we only want to support Imputer.setInputCol -> ImputerModel.setInputCol, then there is no point in having this method at all:

val ms = new Imputer().setInputCol("x1").setOutputCol("x1_").fit(df)

ms.setInputCol("x2").setOutputCol("x2_").transform(df)
// org.apache.spark.sql.AnalysisException: cannot resolve '`x2`' given input columns: [x1];;

as surrogateDF contains only the column used during fit:

ms.surrogateDF
org.apache.spark.sql.DataFrame = [x1: double]

Am I missing something obvious here?

The PR author (Contributor) replied:

@zero323 It is a problem. I will have a follow-up PR to fix this.

A Member replied:

I think the sanity check on the Scala side for inputCol/outputCol and inputCols/outputCols is more for preventing errors when mixing single and multiple columns at the same time, e.g. setting both single and multiple column params, inputCol + outputCols, etc.

It sounds like switching between single and multiple columns between fitting and transforming would be rare.
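
For reference, a minimal sketch of the exclusivity check being discussed, built around the ParamValidators.checkSingleVsMultiColumnParams call that the diff above adds at the start of validateAndTransformSchema. The trait and method names below are invented for illustration; this is not the actual Imputer code.

// Only to show where the check lives. The shared param traits are package-private,
// so this would have to sit inside Spark's org.apache.spark.ml.feature package,
// just like Imputer itself.
package org.apache.spark.ml.feature

import org.apache.spark.ml.param.{ParamValidators, Params}
import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}

private[feature] trait SingleOrMultiColumnParams extends Params
  with HasInputCol with HasInputCols with HasOutputCol with HasOutputCols {

  protected def validateColumnParams(): Unit = {
    // Requires exactly one of inputCol / inputCols to be set, together with the
    // matching output param, and throws IllegalArgumentException otherwise
    // (e.g. when both inputCol and inputCols are set, or inputCol is mixed with outputCols).
    ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol), Seq(outputCols))
  }
}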

zero323 (Member) replied on Jan 15, 2020:

@viirya But then we're back to the question of why we need setInputCol in Models. Should we support

new Estimator().setInputCols(...).fit(...).setInputCol(...).transform(...)

flow at all?

A Member added:

And what about overriding inputCols (taking a subset)

new Estimator().setInputCols(...).fit(...).setInputCols(...).transform(...)

for that matter?

A Member replied:

For a model fitted by an Estimator, I think we usually won't change the input/output column(s). The setter is still useful, though, as there are cases where we might create a model instance directly. For such cases, we need the input column(s) setter.

zero323 (Member) replied on Jan 16, 2020:

Having the option to overwrite outputCol(s) can be useful to avoid name clashes with pre-trained models.

But providing setters for inputs seems more confusing than useful, and the proliferation of Params that support both Col and Cols makes things even fuzzier, as there is no way to tell which variant we have without inspecting Param values.

In general, I am asking because we already have cases like OneHotEncoderModel, which provides setInputCols / setOutputCols but no single-column equivalents.
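
A rough spark-shell-style sketch of that rename-to-avoid-clash case; the DataFrame and column names here are made up for illustration:

import org.apache.spark.ml.feature.Imputer

val df = Seq((1.0, 2.0), (Double.NaN, 3.0)).toDF("x1", "x2")

val mm = new Imputer()
  .setInputCols(Array("x1", "x2"))
  .setOutputCols(Array("x1_", "x2_"))
  .fit(df)

// If the frame being scored already has columns named x1_ / x2_, overriding
// outputCols on the fitted model sidesteps the name clash without refitting;
// the surrogates computed during fit are reused unchanged.
mm.setOutputCols(Array("x1_imputed", "x2_imputed")).transform(df)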


/** @group setParam */
@Since("3.0.0")
def setOutputCol(value: String): this.type = set(outputCol, value)

/** @group setParam */
def setInputCols(value: Array[String]): this.type = set(inputCols, value)

@@ -213,9 +243,11 @@ class ImputerModel private[ml] (

override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq
val (inputColumns, outputColumns) = getInOutCols
val surrogates = surrogateDF.select(inputColumns.map(col): _*).head().toSeq


val newCols = $(inputCols).zip($(outputCols)).zip(surrogates).map {
val newCols = inputColumns.zip(outputColumns).zip(surrogates).map {
case ((inputCol, outputCol), surrogate) =>
val inputType = dataset.schema(inputCol).dataType
val ic = col(inputCol).cast(DoubleType)
@@ -224,7 +256,7 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
.otherwise(ic)
.cast(inputType)
}
dataset.withColumns($(outputCols), newCols).toDF()
dataset.withColumns(outputColumns, newCols).toDF()
}

override def transformSchema(schema: StructType): StructType = {