-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-22799][ML] Bucketizer should throw exception if single- and multi-column params are both set #19993
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-22799][ML] Bucketizer should throw exception if single- and multi-column params are both set #19993
Changes from 9 commits
8f3581c
bb0c0d2
9f56800
f593f5b
2ecdc73
26fe05e
64634b5
9872bfd
d0b8d06
b20fb91
09d652d
a0c0fed
25b9bd4
18bbf61
d9d25b0
8c162a3
7894609
ebc6d16
2bc5cb4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,6 +31,7 @@ import org.json4s.jackson.JsonMethods._ | |
| import org.apache.spark.SparkException | ||
| import org.apache.spark.annotation.{DeveloperApi, Since} | ||
| import org.apache.spark.ml.linalg.{JsonMatrixConverter, JsonVectorConverter, Matrix, Vector} | ||
| import org.apache.spark.ml.param.shared._ | ||
| import org.apache.spark.ml.util.Identifiable | ||
|
|
||
| /** | ||
|
|
@@ -249,6 +250,31 @@ object ParamValidators { | |
| def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = { (value: Array[T]) => | ||
| value.length > lowerBound | ||
| } | ||
|
|
||
| /** | ||
| * Checks that the single-column params (`inputCol`/`outputCol`) and their multi-column | ||
| * counterparts (`inputCols`/`outputCols`) are not both set on the same instance. If they | ||
| * are, an `IllegalArgumentException` is raised. | ||
| * @param model the `Params` instance to validate | ||
| */ | ||
| private[spark] def checkMultiColumnParams(model: Params): Unit = { | ||
| model match { | ||
| case m: HasInputCols with HasInputCol if m.isSet(m.inputCols) && m.isSet(m.inputCol) => | ||
| raiseIncompatibleParamsException("inputCols", "inputCol") | ||
| case m: HasOutputCols with HasInputCol if m.isSet(m.outputCols) && m.isSet(m.inputCol) => | ||
|
||
| raiseIncompatibleParamsException("outputCols", "inputCol") | ||
| case m: HasInputCols with HasOutputCol if m.isSet(m.inputCols) && m.isSet(m.outputCol) => | ||
|
||
| raiseIncompatibleParamsException("inputCols", "outputCol") | ||
| case m: HasOutputCols with HasOutputCol if m.isSet(m.outputCols) && m.isSet(m.outputCol) => | ||
| raiseIncompatibleParamsException("outputCols", "outputCol") | ||
| case _ => | ||
| } | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we need to check other exclusive params, e.g., def checkExclusiveParams(model: Params, params: String*): Unit = {
if (params.filter(model.isSet(_)).size > 1) {
val paramString = params.mkString("`", "`, `", "`")
throw new IllegalArgumentException(s"$paramString are exclusive, but more than one among them are set.")
}
}
ParamValidators.checkExclusiveParams(this, "inputCol", "inputCols")
ParamValidators.checkExclusiveParams(this, "outputCol", "outputCols")
ParamValidators.checkExclusiveParams(this, "inputCol", "splitsArray")
ParamValidators.checkExclusiveParams(this, "inputCols", "splits")
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added this method too in #20146.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can use that method once merged, thanks.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure if #20146 will get merged for
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Based on #20146 (comment) from @WeichenXu123, I think #20146 cannot get merged for 2.3.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If this method looks good to you, maybe you can just copy it from #20146 to use here.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @MLnick @viirya in order to address https://github.com/apache/spark/pull/19993/files#r161682506, I was thinking to let this method as it is (just renaming it as per @viirya suggestion) and only adding an
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think @viirya's method is simpler and more general, so why not use it?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| } | ||
|
|
||
| private[spark] def raiseIncompatibleParamsException( | ||
| paramName1: String, | ||
| paramName2: String): Unit = { | ||
| throw new IllegalArgumentException(s"`$paramName1` and `$paramName2` cannot be both set.") | ||
|
||
| } | ||
| } | ||
|
|
||
| // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -401,15 +401,9 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa | |
| } | ||
| } | ||
|
|
||
| test("Both inputCol and inputCols are set") { | ||
| val bucket = new Bucketizer() | ||
| .setInputCol("feature1") | ||
| .setOutputCol("result") | ||
| .setSplits(Array(-0.5, 0.0, 0.5)) | ||
| .setInputCols(Array("feature1", "feature2")) | ||
|
|
||
| // When both are set, we ignore `inputCols` and just map the column specified by `inputCol`. | ||
| assert(bucket.isBucketizeMultipleColumns() == false) | ||
test("assert exception is thrown if both multi-column and single-column params are set") { | ||
|
||
| val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2") | ||
| ParamsSuite.testMultiColumnParams(classOf[Bucketizer], df) | ||
| } | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,8 +20,11 @@ package org.apache.spark.ml.param | |
| import java.io.{ByteArrayOutputStream, ObjectOutputStream} | ||
|
|
||
| import org.apache.spark.SparkFunSuite | ||
| import org.apache.spark.ml.{Estimator, Transformer} | ||
| import org.apache.spark.ml.linalg.{Vector, Vectors} | ||
| import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols} | ||
|
||
| import org.apache.spark.ml.util.MyParams | ||
| import org.apache.spark.sql.Dataset | ||
|
|
||
| class ParamsSuite extends SparkFunSuite { | ||
|
|
||
|
|
@@ -430,4 +433,45 @@ object ParamsSuite extends SparkFunSuite { | |
| require(copyReturnType === obj.getClass, | ||
| s"${clazz.getName}.copy should return ${clazz.getName} instead of ${copyReturnType.getName}.") | ||
| } | ||
|
|
||
| /** | ||
| * Checks that the class throws an exception in case both `inputCols` and `inputCol` are set and | ||
| * in case both `outputCols` and `outputCol` are set. | ||
| * These checks are performed only if the class extends, respectively, both `HasInputCols` and | ||
|
||
| * `HasInputCol` and both `HasOutputCols` and `HasOutputCol`. | ||
| * | ||
| * @param paramsClass The Class to be checked | ||
| * @param dataset A `Dataset` to use in the tests | ||
| */ | ||
| def testMultiColumnParams(paramsClass: Class[_ <: Params], dataset: Dataset[_]): Unit = { | ||
| val cols = dataset.columns | ||
|
|
||
| if (paramsClass.isAssignableFrom(classOf[HasInputCols]) | ||
| && paramsClass.isAssignableFrom(classOf[HasInputCol])) { | ||
| val model = paramsClass.newInstance() | ||
| model.set(model.asInstanceOf[HasInputCols].inputCols, cols) | ||
| model.set(model.asInstanceOf[HasInputCol].inputCol, cols(0)) | ||
| val e = intercept[IllegalArgumentException] { | ||
| model match { | ||
| case t: Transformer => t.transform(dataset) | ||
| case e: Estimator[_] => e.fit(dataset) | ||
| } | ||
| } | ||
| assert(e.getMessage.contains("cannot be both set")) | ||
| } | ||
|
|
||
| if (paramsClass.isAssignableFrom(classOf[HasOutputCols]) | ||
| && paramsClass.isAssignableFrom(classOf[HasOutputCol])) { | ||
| val model = paramsClass.newInstance() | ||
| model.set(model.asInstanceOf[HasOutputCols].outputCols, cols) | ||
| model.set(model.asInstanceOf[HasOutputCol].outputCol, cols(0)) | ||
| val e = intercept[IllegalArgumentException] { | ||
| model match { | ||
| case t: Transformer => t.transform(dataset) | ||
| case e: Estimator[_] => e.fit(dataset) | ||
| } | ||
| } | ||
| assert(e.getMessage.contains("cannot be both set")) | ||
| } | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems superfluous to now have a separate method for this