[SPARK-13568] [ML] Create feature transformer to impute missing values #11601
hhbyyh wants to merge 54 commits into apache:master
Conversation
|
Test build #52734 has finished for PR 11601 at commit
|
val strategy: Param[String] = new Param(this, "strategy", "strategy for imputation. " +
  "If mean, then replace missing values using the mean along the axis. " +
  "If median, then replace missing values using the median along the axis. " +
  "If most, then replace missing values using the most frequent value along the axis.")
Could you add a param validation function since there are a limited number of valid strategies? You can add an attribute like val supportedMissingValueStrategies = Set("mean", "median", "most") to the Imputer companion object, as is done here.

I added the validation to validateParameter (which should be moved, since it's deprecated). Thanks for the suggestion; I'll add them.
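A minimal sketch of the suggested validation pattern, assuming Spark ML's ParamValidators API (the ImputerStrategies and HasImputeStrategy names are illustrative, not the PR's actual code):

```scala
import org.apache.spark.ml.param.{Param, Params, ParamValidators}

// Supported strategy names defined once, backing a validator that runs
// whenever the param is set.
object ImputerStrategies {
  val supported: Array[String] = Array("mean", "median", "most")
}

trait HasImputeStrategy extends Params {
  // ParamValidators.inArray rejects any value outside the supported set.
  val strategy: Param[String] = new Param[String](this, "strategy",
    s"strategy for imputation, one of: ${ImputerStrategies.supported.mkString(", ")}",
    ParamValidators.inArray[String](ImputerStrategies.supported))
}
```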
|
Looking at the JIRAs, it is unclear whether any concrete decisions were made regarding handling Vectors and how NaN values should be handled in colStats. Is there any update? |
|
I prefer to keep Statistics.colStats(rdd) unchanged for now. As the unit tests in this PR suggest, we can cover Double and Vector. |
|
Test build #52842 has finished for PR 11601 at commit
|
val colStatistics = $(strategy) match {
  case "mean" =>
    filteredDF.selectExpr(s"avg($colName)").first().getDouble(0)
  case "median" =>
I think we should favour using the new approxQuantile SQL stat function here rather than computing exactly.
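For reference, a sketch of that suggestion (df and the "value" column name are placeholders):

```scala
// approxQuantile(col, probabilities, relativeError) returns one Double per
// requested probability; 0.5 yields an approximate median.
val median: Double = df.stat.approxQuantile("value", Array(0.5), 0.001).head
```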
|
Test build #53923 has finished for PR 11601 at commit
|
|
Test build #53931 has finished for PR 11601 at commit
|
|
Test build #73268 has started for PR 11601 at commit |
|
Looks like CI was interrupted. |
/** @group getParam */
def getMissingValue: Double = $(missingValue)

/**
Fix comment indentation here.
 * All Null values in the input column are treated as missing, and so are also imputed.
 */
@Experimental
class Imputer @Since("2.1.0")(override val uid: String)
All @Since annotations -> 2.2.0
/**
 * Params for [[Imputer]] and [[ImputerModel]].
 */
private[feature] trait ImputerParams extends Params with HasInputCols with HasOutputCol {
We don't use HasOutputCol anymore, correct?

Sure, however I didn't get your first comment. Do you mean we should remove the import?
object Imputer extends DefaultParamsReadable[Imputer] {

  /** Set of strategy names that Imputer currently supports. */
  private[ml] val supportedStrategyNames = Set("mean", "median")
Could we factor out the mean and median names into private[ml] vals, to be used instead of the raw strings throughout?
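A sketch of that factoring; the companion object shown is an assumption based on the later diff, which references Imputer.mean and Imputer.median:

```scala
object Imputer extends DefaultParamsReadable[Imputer] {
  // Strategy strings defined once, so call sites can match on Imputer.mean
  // and Imputer.median instead of raw literals.
  private[ml] val mean = "mean"
  private[ml] val median = "median"

  /** Set of strategy names that Imputer currently supports. */
  private[ml] val supportedStrategyNames = Set(mean, median)
}
```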
| case "mean" => filtered.select(avg(inputCol)).first().getDouble(0) | ||
| case "median" => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001)(0) | ||
| } | ||
| surrogate.asInstanceOf[Double] |
is the asInstanceOf[Double] necessary here?

no, will remove it.
| test("ImputerModel read/write") { | ||
| val spark = this.spark | ||
| import spark.implicits._ | ||
| val surrogateDF = Seq(1.234).toDF("myInputCol") |
This should be the "surrogate" col name - though I see we don't actually use it in load or transform.

This happens to be the correct column name for now.

Ok - we should add a test here to check that the column names of instance and newInstance match up. (The below check is just for the actual values of the surrogate, correct?)
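A sketch of such a check, assuming the usual read/write test structure where instance is the saved model and newInstance the one loaded back:

```scala
// Verify the persisted surrogateDF keeps its column names across save/load,
// not just its surrogate values.
assert(instance.surrogateDF.columns.toSeq === newInstance.surrogateDF.columns.toSeq)
```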
var outputDF = dataset
val surrogates = surrogateDF.head().getSeq[Double](0)

$(inputCols).indices.foreach { i =>
You could do $(inputCols).zip($(outputCols)).zip(surrogates).map { case ((inputCol, outputCol), icSurrogate) => ...
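Expanded, that suggestion might read like this sketch (the when/otherwise replacement logic is illustrative, not necessarily the PR's exact code):

```scala
import org.apache.spark.sql.functions.{col, when}

// Pair each input column with its output column and surrogate, then thread
// the DataFrame through withColumn once per pair.
val outputDF = $(inputCols).zip($(outputCols)).zip(surrogates)
  .foldLeft(dataset.toDF()) { case (df, ((inputCol, outputCol), icSurrogate)) =>
    val ic = col(inputCol)
    df.withColumn(outputCol,
      when(ic.isNull || ic.isNaN || ic === $(missingValue), icSurrogate).otherwise(ic))
  }
```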
val localOutputCols = $(outputCols)
var outputSchema = schema

$(inputCols).indices.foreach { i =>
Can do $(inputCols).zip($(outputCols)).foreach { case (inputCol, outputCol) => ...
}
val surrogate = $(strategy) match {
  case "mean" => filtered.select(avg(inputCol)).first().getDouble(0)
  case "median" => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001)(0)
 * Model fitted by [[Imputer]].
 *
 * @param surrogateDF Value by which missing values in the input columns will be replaced. This
 *   is stored using DataFrame with input column names and the corresponding surrogates.
This is misleading - you're just storing the array of surrogates... did you mean something different? Otherwise the comment must be changed.

It sounds like you had the idea of storing the surrogates something like:
+------+---------+
|column|surrogate|
+------+---------+
|  col1|      1.2|
|  col2|      3.4|
|  col3|      5.4|
+------+---------+
?

I refactored it a little for better extensibility.
| inputCol1 | inputCol2 |
|---|---|
| surrogate1 | surrogate2 |
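For illustration, one way such a single-row, wide surrogate DataFrame could be built (column names and values here are placeholders):

```scala
import spark.implicits._

// One row; one column per input column, holding that column's surrogate.
val surrogateDF = Seq((1.2, 3.4)).toDF("inputCol1", "inputCol2")
```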
|
jenkins retest this please |
|
Test build #73753 has finished for PR 11601 at commit
|
|
Thanks a lot for making a pass @MLnick. The last update mainly focused on the interface and behavior change. I'll make a pass and also address your comments. |
|
Hi @MLnick I changed the surrogateDF format for better extensibility in the last update and added unit tests for multi-column support. Let me know if I missed anything.
|
|
Test build #73868 has finished for PR 11601 at commit
|
MLnick left a comment:
Made a pass. A few minor comments.
 * The imputation strategy.
 * If "mean", then replace missing values using the mean value of the feature.
 * If "median", then replace missing values using the approximate median value of the
 * feature (relative error less than 0.001).
I think we should remove the part "(relative error less than 0.001)" - this can be moved to the overall ScalaDoc for Imputer at L95.
/**
 * :: Experimental ::
 * Imputation estimator for completing missing values, either using the mean or the median
 * of the column in which the missing values are located. The input column should be of
As mentioned above at https://github.com/apache/spark/pull/11601/files#r104403880, you can add the note about relative error here. Something like "For computing median, approxQuantile is used with a relative error of X" (provide a ScalaDoc link to approxQuantile).

I didn't add the link as it may break Javadoc generation.

Ah right - perhaps just mention using approxQuantile?
@Since("2.2.0")
def setMissingValue(value: Double): this.type = set(missingValue, value)

import org.apache.spark.ml.feature.Imputer._
This import should probably be above with the others (or within fit)
}
val surrogate = $(strategy) match {
  case Imputer.mean => filtered.select(avg(inputCol)).as[Double].first()
  case Imputer.median => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001).head
Not really sure about the relative error here - perhaps 0.01 is sufficient?

Later perhaps we can even expose it as an expert param (but not for now).

I tried it before. 0.01 and 0.001 actually take the same time, even for a large dataset. Agree we can make it a param later.
override def transform(dataset: Dataset[_]): DataFrame = {
  transformSchema(dataset.schema, logging = true)
  var outputDF = dataset
  val surrogates = surrogateDF.select($(inputCols).head, $(inputCols).tail: _*).head().toSeq
Maybe this is slightly cleaner: surrogateDF.select($(inputCols).map(col): _*)
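In context, the suggested form would read roughly like this sketch:

```scala
import org.apache.spark.sql.functions.col

// Select the surrogate columns in inputCols order without head/tail splitting.
val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq
```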
  .setInputCols(Array("value1", "value2"))
  .setOutputCols(Array("out1"))
  .setStrategy(strategy)
intercept[IllegalArgumentException] {
Also test for the thrown message here, using withClue.
| test("ImputerModel read/write") { | ||
| val spark = this.spark | ||
| import spark.implicits._ | ||
| val surrogateDF = Seq(1.234).toDF("myInputCol") |
There was a problem hiding this comment.
Ok - we should add a test here to check the column names of instance and newInstance match up? (The below check is just for the actual values of the surrogate, correct?
}

object ImputerSuite {
| Seq("mean", "median").foreach { strategy => | ||
| val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) | ||
| .setStrategy(strategy) | ||
| intercept[SparkException] { |
)).toDF("id", "value1", "value2", "value3")
Seq("mean", "median").foreach { strategy =>
  // inputCols and outCols length different
  val imputer = new Imputer()
You can also perhaps use withClue to put a message on the subtest / exception assertion, e.g. withClue("Imputer should fail if inputCols and outputCols are different length").
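A sketch of that ScalaTest pattern (the message fragment asserted on is illustrative):

```scala
withClue("Imputer should fail if inputCols and outputCols are different length") {
  val e = intercept[IllegalArgumentException] {
    imputer.fit(df)
  }
  // Optionally pin down the thrown message as well.
  assert(e.getMessage.contains("inputCols"))
}
```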
|
Test build #74038 has finished for PR 11601 at commit
|
 * Note that the mean/median value is computed after filtering out missing values.
 * All Null values in the input column are treated as missing, and so are also imputed. For
 * computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001.
Ah I see it is here - nevermind
val ic = col(inputCol)
val filtered = dataset.select(ic.cast(DoubleType))
  .filter(ic.isNotNull && ic =!= $(missingValue) && !ic.isNaN)
if (filtered.rdd.isEmpty()) {
I think we can do filtered.take(1).size == 0 which should be more efficient
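That is, something like the following sketch, which avoids converting the Dataset to an RDD just to test emptiness:

```scala
// take(1) fetches at most one row, so emptiness is decided as soon as any
// surviving row is found.
if (filtered.take(1).isEmpty) {
  throw new SparkException(s"surrogate cannot be computed. " +
    s"All the values in $inputCol are Null, Nan or missingValue (${$(missingValue)})")
}
```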
  .filter(ic.isNotNull && ic =!= $(missingValue) && !ic.isNaN)
if (filtered.rdd.isEmpty()) {
  throw new SparkException(s"surrogate cannot be computed. " +
    s"All the values in $inputCol are Null, Nan or missingValue ($missingValue)")
($missingValue) -> ${$(missingValue)}?
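The distinction, as a sketch: inside an s-interpolated string, $missingValue renders the Param object's toString, while ${$(missingValue)} renders its current value.

```scala
// Wrong: interpolates the Param object itself.
s"missingValue ($missingValue)"
// Right: interpolates the param's configured value, e.g. "missingValue (NaN)".
s"missingValue (${$(missingValue)})"
```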
|
Made a few last comments. LGTM. cc @sethah @jkbradley. I am going to merge this for 2.2. Let me know if you have any final comments. |
|
By the way, out of curiosity, I tested things out on a cluster (4x workers, 192 cores & 480GB RAM total) with 100 columns of 100 million doubles each and 1% missing values, both not cached and cached. |
|
Test build #74216 has finished for PR 11601 at commit
|
|
Thanks @MLnick for being the shepherd and providing consistent help on discussion and review. The performance test matches what I got from my local environment. |
|
jenkins retest this please |
|
Created SPARK-19969 to track docs and examples to be done for the 2.2 release. I can help with this if you're tied up. |
|
Test build #74651 has finished for PR 11601 at commit
|
|
Merged to master. Thanks @hhbyyh and also everyone for reviews. |
What changes were proposed in this pull request?
jira: https://issues.apache.org/jira/browse/SPARK-13568
It is quite common to encounter missing values in data sets. It would be useful to implement a Transformer that can impute missing data points, similar to e.g. Imputer in scikit-learn.
Initially, options for imputation could include mean, median and most frequent, but we could add various other approaches. Where possible, existing DataFrame code can be used (e.g. for approximate quantiles, etc.).
Currently this PR supports imputation for Double and Vector (null and NaN in Vector).
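A hedged end-to-end sketch of the multi-column API the PR converged on during review (data values and column names are illustrative):

```scala
import org.apache.spark.ml.feature.Imputer

val df = spark.createDataFrame(Seq(
  (1.0, Double.NaN),
  (2.0, 3.0),
  (Double.NaN, 5.0)
)).toDF("value1", "value2")

val model = new Imputer()
  .setInputCols(Array("value1", "value2"))
  .setOutputCols(Array("out1", "out2"))
  .setStrategy("median")
  .fit(df)

// Missing entries (NaN by default) are replaced by each column's median.
model.transform(df).show()
```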
How was this patch tested?
new unit tests and manual test