-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-22060][ML] Fix CrossValidator/TrainValidationSplit param persist/load bug #19278
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -396,17 +396,24 @@ private[ml] object DefaultParamsReader { | |
|
|
||
| /** | ||
| * Extract Params from metadata, and set them in the instance. | ||
| * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]]. | ||
| * This works if all Params (except params included by `skipParams` list) implement | ||
| * [[org.apache.spark.ml.param.Param.jsonDecode()]]. | ||
| * | ||
| * The params included in `skipParams` won't be set. This is useful if some params don't | ||
|
||
| * implement [[org.apache.spark.ml.param.Param.jsonDecode()]] and need special handling. | ||
| * TODO: Move to [[Metadata]] method | ||
| */ | ||
| def getAndSetParams(instance: Params, metadata: Metadata): Unit = { | ||
| def getAndSetParams(instance: Params, metadata: Metadata, | ||
|
||
| skipParams: Option[List[String]] = None): Unit = { | ||
| implicit val format = DefaultFormats | ||
| metadata.params match { | ||
| case JObject(pairs) => | ||
| pairs.foreach { case (paramName, jsonValue) => | ||
| val param = instance.getParam(paramName) | ||
| val value = param.jsonDecode(compact(render(jsonValue))) | ||
| instance.set(param, value) | ||
| if (skipParams == None || !skipParams.get.contains(paramName)) { | ||
| val param = instance.getParam(paramName) | ||
| val value = param.jsonDecode(compact(render(jsonValue))) | ||
| instance.set(param, value) | ||
| } | ||
| } | ||
| case _ => | ||
| throw new IllegalArgumentException( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -159,12 +159,15 @@ class CrossValidatorSuite | |
| .setEvaluator(evaluator) | ||
| .setNumFolds(20) | ||
| .setEstimatorParamMaps(paramMaps) | ||
| .setSeed(42L) | ||
| .setParallelism(2) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Update the test for the model too, please.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ditto. |
||
|
|
||
| val cv2 = testDefaultReadWrite(cv, testParams = false) | ||
|
|
||
| assert(cv.uid === cv2.uid) | ||
| assert(cv.getNumFolds === cv2.getNumFolds) | ||
| assert(cv.getSeed === cv2.getSeed) | ||
| assert(cv.getParallelism === cv2.getParallelism) | ||
|
|
||
| assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) | ||
| val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,7 +23,7 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressio | |
| import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput | ||
| import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} | ||
| import org.apache.spark.ml.linalg.Vectors | ||
| import org.apache.spark.ml.param.{ParamMap} | ||
| import org.apache.spark.ml.param.ParamMap | ||
| import org.apache.spark.ml.param.shared.HasInputCol | ||
| import org.apache.spark.ml.regression.LinearRegression | ||
| import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} | ||
|
|
@@ -160,11 +160,13 @@ class TrainValidationSplitSuite | |
| .setTrainRatio(0.5) | ||
| .setEstimatorParamMaps(paramMaps) | ||
| .setSeed(42L) | ||
| .setParallelism(2) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you update the test for the Model too please?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No. The model does not own the `parallelism` param.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah, you're right, thanks |
||
|
|
||
| val tvs2 = testDefaultReadWrite(tvs, testParams = false) | ||
|
|
||
| assert(tvs.getTrainRatio === tvs2.getTrainRatio) | ||
| assert(tvs.getSeed === tvs2.getSeed) | ||
| assert(tvs.getParallelism === tvs2.getParallelism) | ||
|
|
||
| ValidatorParamsSuiteHelpers | ||
| .compareParamMaps(tvs.getEstimatorParamMaps, tvs2.getEstimatorParamMaps) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
numFolds is no longer needed