From b14fbab7487a8464ba2a53bb9804e00fd14d3785 Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Mon, 10 Oct 2016 10:33:09 +0800 Subject: [PATCH 1/4] [SPARK-17219][ML] enhance NaN value handling in Bucketizer This PR is an enhancement of the PR with commit ID 57dc326bd00cf0a49da971e9c573c48ae28acaa2. It gives users three options for handling NaN values in the dataset: reserve an extra bucket for the NaN values, remove them, or report an error, by setting handleInvalid to "keep", "skip", or "error" (the default). Usage sketches of the final API appear after the last patch in this series. Before: val bucketizer: Bucketizer = new Bucketizer() .setInputCol("feature") .setOutputCol("result") .setSplits(splits) After: val bucketizer: Bucketizer = new Bucketizer() .setInputCol("feature") .setOutputCol("result") .setSplits(splits) .sethandleInvalid("skip") Signed-off-by: VinceShieh --- docs/ml-features.md | 8 ++- .../apache/spark/ml/feature/Bucketizer.scala | 63 ++++++++++++++++--- .../ml/feature/QuantileDiscretizer.scala | 43 +++++++++++-- .../spark/ml/feature/BucketizerSuite.scala | 8 +-- .../ml/feature/QuantileDiscretizerSuite.scala | 26 +++++--- python/pyspark/ml/feature.py | 8 ++- .../apache/spark/sql/DataFrameStatSuite.scala | 4 ++ 7 files changed, 127 insertions(+), 33 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index a7f710fa52e64..948d8f29a193f 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1104,9 +1104,11 @@ for more details on the API. `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned categorical features. The number of bins is set by the `numBuckets` parameter. It is possible that the number of buckets used will be less than this value, for example, if there are too few -distinct values of the input to create enough distinct quantiles. Note also that NaN values are -handled specially and placed into their own bucket. For example, if 4 buckets are used, then -non-NaN data will be put into buckets[0-3], but NaNs will be counted in a special bucket[4]. +distinct values of the input to create enough distinct quantiles. Note also that QuantileDiscretizer +will raise an error when it finds NaN value in the dataset, but user can also choose to either +keep or remove NaN values within the dataset by setting handleInvalid. If user chooses to keep +NaN values, they will be handled specially and placed into their own bucket, for example, if 4 buckets +are used, then non-NaN data will be put into buckets[0-3], but NaNs will be counted in a special bucket[4]. The bin ranges are chosen using an approximate algorithm (see the documentation for [approxQuantile](api/scala/index.html#org.apache.spark.sql.DataFrameStatFunctions) for a detailed description).
The precision of the approximation can be controlled with the diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index ec0ea05f9e1b1..d04ab5cd7572d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -27,6 +27,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.sql._ +import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} @@ -73,15 +74,52 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** + * Param for how to handle invalid entries. Options are skip (which will filter out rows with + * invalid values), or error (which will throw an error), or keep (which will keep the invalid + * values in certain way). Default behaviour is to report an error for invalid entries. + * + * @group param + */ + @Since("2.1.0") + val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle" + + "invalid entries. Options are skip (which will filter out rows with invalid values), or" + + "error (which will throw an error), or keep (which will keep the invalid values" + + " in certain way). Default behaviour is to report an error for invalid entries.", + ParamValidators.inArray(Array("skip", "error", "keep"))) + + /** @group getParam */ + @Since("2.1.0") + def gethandleInvalid: Option[Boolean] = $(handleInvalid) match { + case "keep" => Some(true) + case "skip" => Some(false) + case _ => None + } + + /** @group setParam */ + @Since("2.1.0") + def sethandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, "error") + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) - val bucketizer = udf { feature: Double => - Bucketizer.binarySearchForBuckets($(splits), feature) + val keepInvalid = gethandleInvalid.isDefined && gethandleInvalid.get + + val bucketizer: UserDefinedFunction = udf { (feature: Double) => + Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid) } - val newCol = bucketizer(dataset($(inputCol))) - val newField = prepOutputField(dataset.schema) - dataset.withColumn($(outputCol), newCol, newField.metadata) + val filteredDataset = { + if (!keepInvalid) { + // "skip" NaN option is set, will filter out NaN values in the dataset + dataset.na.drop.toDF() + } else { + dataset.toDF() + } + } + val newCol = bucketizer(filteredDataset($(inputCol))) + val newField = prepOutputField(filteredDataset.schema) + filteredDataset.withColumn($(outputCol), newCol, newField.metadata) } private def prepOutputField(schema: StructType): StructField = { @@ -126,10 +164,21 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { /** * Binary searching in several buckets to place each data point. + * @param splits array of split points + * @param feature data point + * @param keepInvalid NaN flag. 
+ * Set "true" to make an extra bucket for NaN values; + * Set "false" to report an error for NaN values + * @return bucket for each data point * @throws SparkException if a feature is < splits.head or > splits.last */ - private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = { - if (feature.isNaN) { + + private[feature] def binarySearchForBuckets( + splits: Array[Double], + feature: Double, + keepInvalid: Boolean): Double = { + if (feature.isNaN && keepInvalid) { + // NaN data point found plus "keep" NaN option is set splits.length - 1 } else if (feature == splits.last) { splits.length - 2 diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 05e034d90f6a3..31bf18f3741c1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -66,11 +66,13 @@ private[feature] trait QuantileDiscretizerBase extends Params /** * `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned * categorical features. The number of bins can be set using the `numBuckets` parameter. It is - * possible that the number of buckets used will be less than this value, for example, if there - * are too few distinct values of the input to create enough distinct quantiles. Note also that - * NaN values are handled specially and placed into their own bucket. For example, if 4 buckets - * are used, then non-NaN data will be put into buckets(0-3), but NaNs will be counted in a special - * bucket(4). + * possible that the number of buckets used will be less than this value, for example, if there are + * too few distinct values of the input to create enough distinct quantiles. Note also that + * QuantileDiscretizer will raise an error when it finds NaN value in the dataset, but user can + * also choose to either keep or remove NaN values within the dataset by setting handleInvalid. + * If user chooses to keep NaN values, they will be handled specially and placed into their own + * bucket, for example, if 4 buckets are used, then non-NaN data will be put into buckets[0-3], + * but NaNs will be counted in a special bucket[4]. * The bin ranges are chosen using an approximate algorithm (see the documentation for * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] * for a detailed description). The precision of the approximation can be controlled with the @@ -100,6 +102,33 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("1.6.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** + * Param for how to handle invalid entries. Options are skip (which will filter out rows with + * invalid values), or error (which will throw an error), or keep (which will keep the invalid + * values in certain way). Default behaviour is to report an error for invalid entries. + * + * @group param + */ + @Since("2.1.0") + val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle" + + "invalid entries. Options are skip (which will filter out rows with invalid values), or" + + "error (which will throw an error), or keep (which will keep the invalid values" + + " in certain way). 
Default behaviour is to report an error for invalid entries.", + ParamValidators.inArray(Array("skip", "error", "keep"))) + + /** @group getParam */ + @Since("2.1.0") + def gethandleInvalid: Option[Boolean] = $(handleInvalid) match { + case "keep" => Some(true) + case "skip" => Some(false) + case _ => None + } + + /** @group setParam */ + @Since("2.1.0") + def sethandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, "error") + @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { SchemaUtils.checkNumericType(schema, $(inputCol)) @@ -124,7 +153,9 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" + s" buckets as a result.") } - val bucketizer = new Bucketizer(uid).setSplits(distinctSplits.sorted) + val bucketizer = new Bucketizer(uid) + .setSplits(distinctSplits.sorted) + .sethandleInvalid($(handleInvalid)) copyValues(bucketizer.setParent(this)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 87cdceb267387..5066238d06ce4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -98,6 +98,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setInputCol("feature") .setOutputCol("result") .setSplits(splits) + .sethandleInvalid("keep") bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { case Row(x: Double, y: Double) => @@ -111,8 +112,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa withClue("Invalid NaN split was not caught as an invalid split!") { intercept[IllegalArgumentException] { val bucketizer: Bucketizer = new Bucketizer() - .setInputCol("feature") - .setOutputCol("result") .setSplits(splits) } } @@ -138,7 +137,8 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val data = Array.fill(100)(Random.nextDouble()) val splits: Array[Double] = Double.NegativeInfinity +: Array.fill(10)(Random.nextDouble()).sorted :+ Double.PositiveInfinity - val bsResult = Vectors.dense(data.map(x => Bucketizer.binarySearchForBuckets(splits, x))) + val bsResult = Vectors.dense(data.map(x => + Bucketizer.binarySearchForBuckets(splits, x, false))) val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x))) assert(bsResult ~== lsResult absTol 1e-5) } @@ -169,7 +169,7 @@ private object BucketizerSuite extends SparkFunSuite { /** Check all values in splits, plus values between all splits. 
*/ def checkBinarySearch(splits: Array[Double]): Unit = { def testFeature(feature: Double, expectedBucket: Double): Unit = { - assert(Bucketizer.binarySearchForBuckets(splits, feature) === expectedBucket, + assert(Bucketizer.binarySearchForBuckets(splits, feature, false) === expectedBucket, s"Expected feature value $feature to be in bucket $expectedBucket with splits:" + s" ${splits.mkString(", ")}") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 6822594044a56..7464bde5b3e5e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql._ import org.apache.spark.sql.functions.udf class QuantileDiscretizerSuite @@ -76,20 +76,26 @@ class QuantileDiscretizerSuite import spark.implicits._ val numBuckets = 3 - val df = sc.parallelize(Array(1.0, 1.0, 1.0, Double.NaN)) - .map(Tuple1.apply).toDF("input") + val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9, Double.NaN, Double.NaN, Double.NaN) + val expectedKeep = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0) + val expectedSkip = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0) + val discretizer = new QuantileDiscretizer() .setInputCol("input") .setOutputCol("result") .setNumBuckets(numBuckets) - // Reserve extra one bucket for NaN - val expectedNumBuckets = discretizer.fit(df).getSplits.length - 1 - val result = discretizer.fit(df).transform(df) - val observedNumBuckets = result.select("result").distinct.count - assert(observedNumBuckets == expectedNumBuckets, - s"Observed number of buckets are not correct." + - s" Expected $expectedNumBuckets but found $observedNumBuckets") + List(("keep", expectedKeep), ("skip", expectedSkip)).foreach{ + case(u, v) => + discretizer.sethandleInvalid(u) + val dataFrame: DataFrame = validData.zip(v).toSeq.toDF("input", "expected") + val result = discretizer.fit(dataFrame).transform(dataFrame) + result.select("result", "expected").collect().foreach { + case Row(x: Double, y: Double) => + assert(x === y, + s"The feature value is not correct after bucketing. Expected $y but found $x") + } + } } test("Test transform method on unseen data") { diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 64b21caa616ec..469c963772760 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1157,9 +1157,11 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadab categorical features. The number of bins can be set using the :py:attr:`numBuckets` parameter. It is possible that the number of buckets used will be less than this value, for example, if there are too few distinct values of the input to create enough distinct quantiles. Note also - that NaN values are handled specially and placed into their own bucket. For example, if 4 - buckets are used, then non-NaN data will be put into buckets(0-3), but NaNs will be counted in - a special bucket(4). 
+ that QuantileDiscretizer will raise an error when it finds NaN value in the dataset, but user + can also choose to either keep or remove NaN values within the dataset by setting + handleInvalid. If user chooses to keep NaN values, they will be handled specially and placed + into their own bucket, for example, if 4 buckets are used, then non-NaN data will be put into + buckets[0-3], but NaNs will be counted in a special bucket[4]. The bin ranges are chosen using an approximate algorithm (see the documentation for :py:meth:`~.DataFrameStatFunctions.approxQuantile` for a detailed description). The precision of the approximation can be controlled with the diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 73026c749db45..726773ed93658 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -150,6 +150,10 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { assert(math.abs(d1 - 2 * q1 * n) < error_double) assert(math.abs(d2 - 2 * q2 * n) < error_double) } + // test approxQuantile on NaN values + val dfNaN = Array(Double.NaN, 1.0, Double.NaN, Double.NaN).toSeq.toDF("input") + val resNaN = dfNaN.stat.approxQuantile("input", Array(q1, q2), epsilons(0)) + assert(resNaN.count(_.isNaN) == 0) } test("crosstab") { From 5274d4a3703193a59607635f80eb9e3ebe61552c Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Tue, 25 Oct 2016 14:13:56 +0800 Subject: [PATCH 2/4] [SPARK-17219][ML] enhance NaN value handling in Bucketizer This PR is an enhancement of the PR with commit ID 57dc326bd00cf0a49da971e9c573c48ae28acaa2. It gives users three options for handling NaN values in the dataset: reserve an extra bucket for the NaN values, remove them, or report an error, by setting handleInvalid to "keep", "skip", or "error" (the default). Before: val bucketizer: Bucketizer = new Bucketizer() .setInputCol("feature") .setOutputCol("result") .setSplits(splits) After: val bucketizer: Bucketizer = new Bucketizer() .setInputCol("feature") .setOutputCol("result") .setSplits(splits) .sethandleInvalid("skip") Signed-off-by: VinceShieh --- .../apache/spark/ml/feature/Bucketizer.scala | 28 ++++++++----------- .../ml/feature/QuantileDiscretizer.scala | 8 ++---- 2 files changed, 14 insertions(+), 22 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index d04ab5cd7572d..e568dc0556a57 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -77,8 +77,8 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String /** * Param for how to handle invalid entries. Options are skip (which will filter out rows with * invalid values), or error (which will throw an error), or keep (which will keep the invalid - * values in certain way).
+ * Default: "error" * @group param */ @Since("2.1.0") @@ -90,11 +90,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String /** @group getParam */ @Since("2.1.0") - def gethandleInvalid: Option[Boolean] = $(handleInvalid) match { - case "keep" => Some(true) - case "skip" => Some(false) - case _ => None - } + def gethandleInvalid: String = $(handleInvalid) /** @group setParam */ @Since("2.1.0") @@ -104,19 +100,19 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) - val keepInvalid = gethandleInvalid.isDefined && gethandleInvalid.get - - val bucketizer: UserDefinedFunction = udf { (feature: Double) => - Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid) - } - val filteredDataset = { - if (!keepInvalid) { + val (filteredDataset, keepInvalid) = { + if ("skip" == gethandleInvalid) { // "skip" NaN option is set, will filter out NaN values in the dataset - dataset.na.drop.toDF() + (dataset.na.drop.toDF(), false) } else { - dataset.toDF() + (dataset.toDF(), "keep" == gethandleInvalid) } } + + val bucketizer: UserDefinedFunction = udf { (feature: Double) => + Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid) + } + val newCol = bucketizer(filteredDataset($(inputCol))) val newField = prepOutputField(filteredDataset.schema) filteredDataset.withColumn($(outputCol), newCol, newField.metadata) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 31bf18f3741c1..5a90abba242f2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -106,7 +106,7 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui * Param for how to handle invalid entries. Options are skip (which will filter out rows with * invalid values), or error (which will throw an error), or keep (which will keep the invalid * values in certain way). Default behaviour is to report an error for invalid entries. - * + * Default: "error" * @group param */ @Since("2.1.0") @@ -118,11 +118,7 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui /** @group getParam */ @Since("2.1.0") - def gethandleInvalid: Option[Boolean] = $(handleInvalid) match { - case "keep" => Some(true) - case "skip" => Some(false) - case _ => None - } + def gethandleInvalid: String = $(handleInvalid) /** @group setParam */ @Since("2.1.0") From 2f98d31118413e61e1aa0431da402c41aa1ca5a6 Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Wed, 26 Oct 2016 11:12:26 +0800 Subject: [PATCH 3/4] revert changes in feature.py Signed-off-by: VinceShieh --- python/pyspark/ml/feature.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 469c963772760..ee86207e7744a 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1155,13 +1155,6 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadab `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned categorical features. The number of bins can be set using the :py:attr:`numBuckets` parameter. 
- It is possible that the number of buckets used will be less than this value, for example, if - there are too few distinct values of the input to create enough distinct quantiles. Note also - that QuantileDiscretizer will raise an error when it finds NaN value in the dataset, but user - can also choose to either keep or remove NaN values within the dataset by setting - handleInvalid. If user chooses to keep NaN values, they will be handled specially and placed - into their own bucket, for example, if 4 buckets are used, then non-NaN data will be put into - buckets[0-3], but NaNs will be counted in a special bucket[4]. The bin ranges are chosen using an approximate algorithm (see the documentation for :py:meth:`~.DataFrameStatFunctions.approxQuantile` for a detailed description). The precision of the approximation can be controlled with the From 2644235f111bbbf43fd1f30d24d318735553e034 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 26 Oct 2016 13:01:27 -0700 Subject: [PATCH 4/4] Cleanups: docs cleanups, slightly improved unit test coverage, fixed naming of set/get for handleInvalid --- docs/ml-features.md | 13 ++-- .../apache/spark/ml/feature/Bucketizer.scala | 44 +++++++++----- .../ml/feature/QuantileDiscretizer.scala | 60 ++++++++++--------- .../spark/ml/feature/BucketizerSuite.scala | 20 +++++-- .../ml/feature/QuantileDiscretizerSuite.scala | 11 +++- .../apache/spark/sql/DataFrameStatSuite.scala | 6 +- 6 files changed, 97 insertions(+), 57 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index 948d8f29a193f..64c6a160239cc 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1103,13 +1103,16 @@ for more details on the API. `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned categorical features. The number of bins is set by the `numBuckets` parameter. It is possible -that the number of buckets used will be less than this value, for example, if there are too few -distinct values of the input to create enough distinct quantiles. Note also that QuantileDiscretizer -will raise an error when it finds NaN value in the dataset, but user can also choose to either -keep or remove NaN values within the dataset by setting handleInvalid. If user chooses to keep +that the number of buckets used will be smaller than this value, for example, if there are too few +distinct values of the input to create enough distinct quantiles. + +NaN values: Note also that QuantileDiscretizer +will raise an error when it finds NaN values in the dataset, but the user can also choose to either +keep or remove NaN values within the dataset by setting `handleInvalid`. If the user chooses to keep NaN values, they will be handled specially and placed into their own bucket, for example, if 4 buckets are used, then non-NaN data will be put into buckets[0-3], but NaNs will be counted in a special bucket[4]. -The bin ranges are chosen using an approximate algorithm (see the documentation for + +Algorithm: The bin ranges are chosen using an approximate algorithm (see the documentation for [approxQuantile](api/scala/index.html#org.apache.spark.sql.DataFrameStatFunctions) for a detailed description). The precision of the approximation can be controlled with the `relativeError` parameter. 
When set to zero, exact quantiles are calculated diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index e568dc0556a57..1143f0f565ebd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -47,6 +47,9 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String * also includes y. Splits should be of length >= 3 and strictly increasing. * Values at -inf, inf must be explicitly provided to cover all Double values; * otherwise, values outside the splits specified will be treated as errors. + * + * See also [[handleInvalid]], which can optionally create an additional bucket for NaN values. + * * @group param */ @Since("1.4.0") @@ -75,37 +78,36 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String def setOutputCol(value: String): this.type = set(outputCol, value) /** - * Param for how to handle invalid entries. Options are skip (which will filter out rows with - * invalid values), or error (which will throw an error), or keep (which will keep the invalid - * values in certain way). + * Param for how to handle invalid entries. Options are skip (filter out rows with + * invalid values), error (throw an error), or keep (keep invalid values in a special additional + * bucket). * Default: "error" * @group param */ @Since("2.1.0") val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle" + - "invalid entries. Options are skip (which will filter out rows with invalid values), or" + - "error (which will throw an error), or keep (which will keep the invalid values" + - " in certain way). Default behaviour is to report an error for invalid entries.", - ParamValidators.inArray(Array("skip", "error", "keep"))) + "invalid entries. 
Options are skip (filter out rows with invalid values), " + + "error (throw an error), or keep (keep invalid values in a special additional bucket).", + ParamValidators.inArray(Bucketizer.supportedHandleInvalid)) /** @group getParam */ @Since("2.1.0") - def gethandleInvalid: String = $(handleInvalid) + def getHandleInvalid: String = $(handleInvalid) /** @group setParam */ @Since("2.1.0") - def sethandleInvalid(value: String): this.type = set(handleInvalid, value) - setDefault(handleInvalid, "error") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, Bucketizer.ERROR_INVALID) @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) val (filteredDataset, keepInvalid) = { - if ("skip" == gethandleInvalid) { + if (getHandleInvalid == Bucketizer.SKIP_INVALID) { // "skip" NaN option is set, will filter out NaN values in the dataset - (dataset.na.drop.toDF(), false) + (dataset.na.drop().toDF(), false) } else { - (dataset.toDF(), "keep" == gethandleInvalid) + (dataset.toDF(), getHandleInvalid == Bucketizer.KEEP_INVALID) } } @@ -140,6 +142,12 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("1.6.0") object Bucketizer extends DefaultParamsReadable[Bucketizer] { + private[feature] val SKIP_INVALID: String = "skip" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val KEEP_INVALID: String = "keep" + private[feature] val supportedHandleInvalid: Array[String] = + Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) + /** * We require splits to be of length >= 3 and to be in strictly increasing order. * No NaN split should be accepted. @@ -173,9 +181,13 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { splits: Array[Double], feature: Double, keepInvalid: Boolean): Double = { - if (feature.isNaN && keepInvalid) { - // NaN data point found plus "keep" NaN option is set - splits.length - 1 + if (feature.isNaN) { + if (keepInvalid) { + splits.length - 1 + } else { + throw new SparkException("Bucketizer encountered NaN value. To handle or skip NaNs," + + " try setting Bucketizer.handleInvalid.") + } } else if (feature == splits.last) { splits.length - 2 } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 5a90abba242f2..b9e01dde70d85 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -36,6 +36,9 @@ private[feature] trait QuantileDiscretizerBase extends Params /** * Number of buckets (quantiles, or categories) into which data points are grouped. Must * be >= 2. + * + * See also [[handleInvalid]], which can optionally create an additional bucket for NaN values. + * * default: 2 * @group param */ @@ -61,19 +64,41 @@ private[feature] trait QuantileDiscretizerBase extends Params /** @group getParam */ def getRelativeError: Double = getOrDefault(relativeError) + + /** + * Param for how to handle invalid entries. Options are skip (filter out rows with + * invalid values), error (throw an error), or keep (keep invalid values in a special additional + * bucket). + * Default: "error" + * @group param + */ + @Since("2.1.0") + val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle" + + "invalid entries. 
Options are skip (filter out rows with invalid values), " + + "error (throw an error), or keep (keep invalid values in a special additional bucket).", + ParamValidators.inArray(Bucketizer.supportedHandleInvalid)) + setDefault(handleInvalid, Bucketizer.ERROR_INVALID) + + /** @group getParam */ + @Since("2.1.0") + def getHandleInvalid: String = $(handleInvalid) + } /** * `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned * categorical features. The number of bins can be set using the `numBuckets` parameter. It is - * possible that the number of buckets used will be less than this value, for example, if there are - * too few distinct values of the input to create enough distinct quantiles. Note also that - * QuantileDiscretizer will raise an error when it finds NaN value in the dataset, but user can - * also choose to either keep or remove NaN values within the dataset by setting handleInvalid. - * If user chooses to keep NaN values, they will be handled specially and placed into their own + * possible that the number of buckets used will be smaller than this value, for example, if there + * are too few distinct values of the input to create enough distinct quantiles. + * + * NaN handling: Note also that + * QuantileDiscretizer will raise an error when it finds NaN values in the dataset, but the user can + * also choose to either keep or remove NaN values within the dataset by setting `handleInvalid`. + * If the user chooses to keep NaN values, they will be handled specially and placed into their own * bucket, for example, if 4 buckets are used, then non-NaN data will be put into buckets[0-3], * but NaNs will be counted in a special bucket[4]. - * The bin ranges are chosen using an approximate algorithm (see the documentation for + * + * Algorithm: The bin ranges are chosen using an approximate algorithm (see the documentation for * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] * for a detailed description). The precision of the approximation can be controlled with the * `relativeError` parameter. The lower and upper bin bounds will be `-Infinity` and `+Infinity`, @@ -102,28 +127,9 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("1.6.0") def setOutputCol(value: String): this.type = set(outputCol, value) - /** - * Param for how to handle invalid entries. Options are skip (which will filter out rows with - * invalid values), or error (which will throw an error), or keep (which will keep the invalid - * values in certain way). Default behaviour is to report an error for invalid entries. - * Default: "error" - * @group param - */ - @Since("2.1.0") - val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle" + - "invalid entries. Options are skip (which will filter out rows with invalid values), or" + - "error (which will throw an error), or keep (which will keep the invalid values" + - " in certain way). 
Default behaviour is to report an error for invalid entries.", - ParamValidators.inArray(Array("skip", "error", "keep"))) - - /** @group getParam */ - @Since("2.1.0") - def gethandleInvalid: String = $(handleInvalid) - /** @group setParam */ @Since("2.1.0") - def sethandleInvalid(value: String): this.type = set(handleInvalid, value) - setDefault(handleInvalid, "error") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { @@ -151,7 +157,7 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui } val bucketizer = new Bucketizer(uid) .setSplits(distinctSplits.sorted) - .sethandleInvalid($(handleInvalid)) + .setHandleInvalid($(handleInvalid)) copyValues(bucketizer.setParent(this)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 5066238d06ce4..aac29137d7911 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -98,21 +98,33 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setInputCol("feature") .setOutputCol("result") .setSplits(splits) - .sethandleInvalid("keep") + bucketizer.setHandleInvalid("keep") bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. Expected $y but found $x") } + + bucketizer.setHandleInvalid("skip") + val skipResults: Array[Double] = bucketizer.transform(dataFrame) + .select("result").as[Double].collect() + assert(skipResults.length === 7) + assert(skipResults.forall(_ !== 4.0)) + + bucketizer.setHandleInvalid("error") + withClue("Bucketizer should throw error when setHandleInvalid=error and given NaN values") { + intercept[SparkException] { + bucketizer.transform(dataFrame).collect() + } + } } test("Bucket continuous features, with NaN splits") { val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity, Double.NaN) - withClue("Invalid NaN split was not caught as an invalid split!") { + withClue("Invalid NaN split was not caught during Bucketizer initialization") { intercept[IllegalArgumentException] { - val bucketizer: Bucketizer = new Bucketizer() - .setSplits(splits) + new Bucketizer().setSplits(splits) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 7464bde5b3e5e..f219f775b2186 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql._ @@ -85,9 +85,16 @@ class QuantileDiscretizerSuite .setOutputCol("result") .setNumBuckets(numBuckets) + withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") { + val dataFrame: DataFrame = validData.toSeq.toDF("input") + intercept[SparkException] { + discretizer.fit(dataFrame).transform(dataFrame).collect() + } + } + 
List(("keep", expectedKeep), ("skip", expectedSkip)).foreach{ case(u, v) => - discretizer.sethandleInvalid(u) + discretizer.setHandleInvalid(u) val dataFrame: DataFrame = validData.zip(v).toSeq.toDF("input", "expected") val result = discretizer.fit(dataFrame).transform(dataFrame) result.select("result", "expected").collect().foreach { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 726773ed93658..1383208874a19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -151,9 +151,9 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { assert(math.abs(d2 - 2 * q2 * n) < error_double) } // test approxQuantile on NaN values - val dfNaN = Array(Double.NaN, 1.0, Double.NaN, Double.NaN).toSeq.toDF("input") - val resNaN = dfNaN.stat.approxQuantile("input", Array(q1, q2), epsilons(0)) - assert(resNaN.count(_.isNaN) == 0) + val dfNaN = Seq(Double.NaN, 1.0, Double.NaN, Double.NaN).toDF("input") + val resNaN = dfNaN.stat.approxQuantile("input", Array(q1, q2), epsilons.head) + assert(resNaN.count(_.isNaN) === 0) } test("crosstab") {
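For reference, a minimal usage sketch of the Bucketizer behavior this series ends with (using the patch-4 method name setHandleInvalid). The object name, splits, and data values here are illustrative only, and a local Spark build containing these patches is assumed:

    import org.apache.spark.ml.feature.Bucketizer
    import org.apache.spark.sql.SparkSession

    object BucketizerNaNExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local")
          .appName("BucketizerNaNExample")
          .getOrCreate()
        import spark.implicits._

        // Five split points define four regular buckets, indexed 0.0 through 3.0.
        val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
        val data = Seq(-0.9, -0.2, 0.3, 0.9, Double.NaN).toDF("feature")

        val bucketizer = new Bucketizer()
          .setInputCol("feature")
          .setOutputCol("result")
          .setSplits(splits)
          .setHandleInvalid("keep") // "keep" | "skip" | "error" (default)

        // "keep": the NaN row goes to the extra bucket 4.0 (splits.length - 1).
        // "skip": the NaN row is dropped (dataset.na.drop) before bucketing.
        // "error": transform throws SparkException on the first NaN feature value.
        bucketizer.transform(data).show()

        spark.stop()
      }
    }

One caveat worth noting when reviewing the transform change: with "skip" the filtering uses dataset.na.drop over the whole row, so a row is removed if any of its columns is null or NaN, not only the input column.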
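QuantileDiscretizer gains the same param and simply forwards it to the Bucketizer produced by fit(). A shell-style sketch under the same assumptions (an active SparkSession named spark; column names and data are illustrative):

    import org.apache.spark.ml.feature.QuantileDiscretizer
    import spark.implicits._

    val df = Seq(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9, Double.NaN).toDF("input")

    val discretizer = new QuantileDiscretizer()
      .setInputCol("input")
      .setOutputCol("result")
      .setNumBuckets(3)
      .setHandleInvalid("keep")

    // fit() derives the splits with approxQuantile and copies handleInvalid onto
    // the resulting Bucketizer, so with "keep" the NaN row lands in the extra
    // bucket 3.0 while the non-NaN rows occupy buckets 0.0 through 2.0.
    discretizer.fit(df).transform(df).show()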
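The lookup rule at the heart of Bucketizer.binarySearchForBuckets reduces to the standalone sketch below. This is a reimplementation for illustration, not the private Spark method; in particular the range checks that raise SparkException for finite values outside splits.head..splits.last are omitted:

    // Buckets are left-closed, [splits(i), splits(i+1)), except the top bucket,
    // which also includes the value splits.last itself.
    def bucketFor(splits: Array[Double], feature: Double, keepInvalid: Boolean): Double = {
      if (feature.isNaN) {
        if (keepInvalid) {
          splits.length - 1 // the extra bucket reserved for NaN values
        } else {
          throw new RuntimeException("NaN value encountered; set handleInvalid to keep or skip")
        }
      } else if (feature == splits.last) {
        splits.length - 2 // fold the upper boundary into the top regular bucket
      } else {
        val idx = java.util.Arrays.binarySearch(splits, feature)
        // binarySearch returns the match index, or -(insertionPoint) - 1 when the
        // value is absent; both cases map to the containing bucket as follows.
        if (idx >= 0) idx else -idx - 2
      }
    }

For example, with splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity), bucketFor(splits, 0.3, keepInvalid = true) is 2.0 and bucketFor(splits, Double.NaN, keepInvalid = true) is 4.0.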