-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-21087] [ML] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala #19208
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-21087] [ML] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala #19208
Changes from 2 commits
46d3ab3
ae13440
e0f4ce6
a33c4ea
e009ee1
931fa6c
2a83fb5
f2ef609
7bacfca
654e4d5
7e997da
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,8 +17,7 @@ | |
|
|
||
| package org.apache.spark.ml.tuning | ||
|
|
||
| import java.io.IOException | ||
| import java.util.{List => JList} | ||
| import java.util.{List => JList, Locale} | ||
|
|
||
| import scala.collection.JavaConverters._ | ||
| import scala.concurrent.Future | ||
|
|
@@ -176,7 +175,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) | |
| logInfo(s"Best cross-validation metric: $bestMetric.") | ||
| val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] | ||
| instr.logSuccess(bestModel) | ||
| copyValues(new CrossValidatorModel(uid, bestModel, metrics, subModels).setParent(this)) | ||
| copyValues(new CrossValidatorModel(uid, bestModel, metrics) | ||
| .setSubModels(subModels).setParent(this)) | ||
| } | ||
|
|
||
| @Since("1.4.0") | ||
|
|
@@ -252,19 +252,29 @@ object CrossValidator extends MLReadable[CrossValidator] { | |
| class CrossValidatorModel private[ml] ( | ||
| @Since("1.4.0") override val uid: String, | ||
| @Since("1.2.0") val bestModel: Model[_], | ||
| @Since("1.5.0") val avgMetrics: Array[Double], | ||
| @Since("2.3.0") val subModels: Option[Array[Array[Model[_]]]]) | ||
| @Since("1.5.0") val avgMetrics: Array[Double]) | ||
| extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable { | ||
|
|
||
| /** A Python-friendly auxiliary constructor. */ | ||
| private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: JList[Double]) = { | ||
| this(uid, bestModel, avgMetrics.asScala.toArray, null) | ||
| this(uid, bestModel, avgMetrics.asScala.toArray) | ||
| } | ||
|
|
||
| private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: Array[Double]) = { | ||
| this(uid, bestModel, avgMetrics, null) | ||
| private var _subModels: Option[Array[Array[Model[_]]]] = None | ||
|
|
||
| @Since("2.3.0") | ||
| private[tuning] def setSubModels(subModels: Option[Array[Array[Model[_]]]]) | ||
| : CrossValidatorModel = { | ||
| _subModels = subModels | ||
| this | ||
| } | ||
|
|
||
| @Since("2.3.0") | ||
| def subModels: Array[Array[Model[_]]] = _subModels.get | ||
|
||
|
|
||
| @Since("2.3.0") | ||
| def hasSubModels: Boolean = _subModels.isDefined | ||
|
|
||
| @Since("2.0.0") | ||
| override def transform(dataset: Dataset[_]): DataFrame = { | ||
| transformSchema(dataset.schema, logging = true) | ||
|
|
@@ -281,20 +291,13 @@ class CrossValidatorModel private[ml] ( | |
| val copied = new CrossValidatorModel( | ||
| uid, | ||
| bestModel.copy(extra).asInstanceOf[Model[_]], | ||
| avgMetrics.clone(), | ||
| CrossValidatorModel.copySubModels(subModels)) | ||
| avgMetrics.clone() | ||
| ).setSubModels(CrossValidatorModel.copySubModels(_subModels)) | ||
| copyValues(copied, extra).setParent(parent) | ||
| } | ||
|
|
||
| @Since("1.6.0") | ||
| override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this) | ||
|
|
||
| @Since("2.3.0") | ||
| @throws[IOException]("If the input path already exists but overwrite is not enabled.") | ||
| def save(path: String, persistSubModels: Boolean): Unit = { | ||
| write.asInstanceOf[CrossValidatorModel.CrossValidatorModelWriter] | ||
| .persistSubModels(persistSubModels).save(path) | ||
| } | ||
| } | ||
|
|
||
| @Since("1.6.0") | ||
|
|
@@ -325,14 +328,19 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { | |
|
|
||
| ValidatorParams.validateParams(instance) | ||
|
|
||
| protected var shouldPersistSubModels: Boolean = false | ||
| protected var shouldPersistSubModels: Boolean = if (instance.hasSubModels) true else false | ||
|
|
||
| /** | ||
| * Set option for persist sub models. | ||
| * Extra options for CrossValidatorModelWriter, current support "persistSubModels". | ||
| * if sub models exsit, the default value for option "persistSubModels" is "true". | ||
|
||
| */ | ||
| @Since("2.3.0") | ||
| def persistSubModels(persist: Boolean): this.type = { | ||
| shouldPersistSubModels = persist | ||
| override def option(key: String, value: String): this.type = { | ||
| key.toLowerCase(Locale.ROOT) match { | ||
| case "persistsubmodels" => shouldPersistSubModels = value.toBoolean | ||
| case _ => throw new IllegalArgumentException( | ||
| s"Illegal option ${key} for CrossValidatorModelWriter") | ||
| } | ||
| this | ||
| } | ||
|
|
||
|
|
@@ -344,13 +352,13 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { | |
| val bestModelPath = new Path(path, "bestModel").toString | ||
| instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) | ||
| if (shouldPersistSubModels) { | ||
| require(instance.subModels.isDefined, "Cannot get sub models to persist.") | ||
| require(instance.hasSubModels, "Cannot get sub models to persist.") | ||
|
||
| val subModelsPath = new Path(path, "subModels") | ||
| for (splitIndex <- 0 until instance.getNumFolds) { | ||
| val splitPath = new Path(subModelsPath, splitIndex.toString) | ||
|
||
| for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) { | ||
| val modelPath = new Path(splitPath, paramIndex.toString).toString | ||
| instance.subModels.get(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath) | ||
| instance.subModels(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath) | ||
| } | ||
| } | ||
| } | ||
|
|
@@ -388,7 +396,8 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { | |
| Some(_subModels) | ||
| } else None | ||
|
|
||
| val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics, subModels) | ||
| val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics) | ||
| .setSubModels(subModels) | ||
| model.set(model.estimator, estimator) | ||
| .set(model.evaluator, evaluator) | ||
| .set(model.estimatorParamMaps, estimatorParamMaps) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,8 +17,7 @@ | |
|
|
||
| package org.apache.spark.ml.tuning | ||
|
|
||
| import java.io.IOException | ||
| import java.util.{List => JList} | ||
| import java.util.{List => JList, Locale} | ||
|
|
||
| import scala.collection.JavaConverters._ | ||
| import scala.concurrent.Future | ||
|
|
@@ -173,7 +172,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St | |
| logInfo(s"Best train validation split metric: $bestMetric.") | ||
| val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] | ||
| instr.logSuccess(bestModel) | ||
| copyValues(new TrainValidationSplitModel(uid, bestModel, metrics, subModels).setParent(this)) | ||
| copyValues(new TrainValidationSplitModel(uid, bestModel, metrics) | ||
| .setSubModels(subModels).setParent(this)) | ||
| } | ||
|
|
||
| @Since("1.5.0") | ||
|
|
@@ -245,19 +245,29 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] { | |
| class TrainValidationSplitModel private[ml] ( | ||
| @Since("1.5.0") override val uid: String, | ||
| @Since("1.5.0") val bestModel: Model[_], | ||
| @Since("1.5.0") val validationMetrics: Array[Double], | ||
| @Since("2.3.0") val subModels: Option[Array[Model[_]]]) | ||
| @Since("1.5.0") val validationMetrics: Array[Double]) | ||
| extends Model[TrainValidationSplitModel] with TrainValidationSplitParams with MLWritable { | ||
|
|
||
| /** A Python-friendly auxiliary constructor. */ | ||
| private[ml] def this(uid: String, bestModel: Model[_], validationMetrics: JList[Double]) = { | ||
| this(uid, bestModel, validationMetrics.asScala.toArray, null) | ||
| this(uid, bestModel, validationMetrics.asScala.toArray) | ||
| } | ||
|
|
||
| private[ml] def this(uid: String, bestModel: Model[_], validationMetrics: Array[Double]) = { | ||
| this(uid, bestModel, validationMetrics, null) | ||
| private var _subModels: Option[Array[Model[_]]] = None | ||
|
|
||
| @Since("2.3.0") | ||
| private[tuning] def setSubModels(subModels: Option[Array[Model[_]]]) | ||
| : TrainValidationSplitModel = { | ||
| _subModels = subModels | ||
| this | ||
| } | ||
|
|
||
| @Since("2.3.0") | ||
| def subModels: Array[Model[_]] = _subModels.get | ||
|
|
||
| @Since("2.3.0") | ||
| def hasSubModels: Boolean = _subModels.isDefined | ||
|
|
||
| @Since("2.0.0") | ||
| override def transform(dataset: Dataset[_]): DataFrame = { | ||
| transformSchema(dataset.schema, logging = true) | ||
|
|
@@ -274,20 +284,13 @@ class TrainValidationSplitModel private[ml] ( | |
| val copied = new TrainValidationSplitModel ( | ||
| uid, | ||
| bestModel.copy(extra).asInstanceOf[Model[_]], | ||
| validationMetrics.clone(), | ||
| TrainValidationSplitModel.copySubModels(subModels)) | ||
| validationMetrics.clone() | ||
| ).setSubModels(TrainValidationSplitModel.copySubModels(_subModels)) | ||
| copyValues(copied, extra).setParent(parent) | ||
| } | ||
|
|
||
| @Since("2.0.0") | ||
| override def write: MLWriter = new TrainValidationSplitModel.TrainValidationSplitModelWriter(this) | ||
|
|
||
| @Since("2.3.0") | ||
| @throws[IOException]("If the input path already exists but overwrite is not enabled.") | ||
| def save(path: String, persistSubModels: Boolean): Unit = { | ||
| write.asInstanceOf[TrainValidationSplitModel.TrainValidationSplitModelWriter] | ||
| .persistSubModels(persistSubModels).save(path) | ||
| } | ||
| } | ||
|
|
||
| @Since("2.0.0") | ||
|
|
@@ -315,14 +318,19 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { | |
|
|
||
| ValidatorParams.validateParams(instance) | ||
|
|
||
| protected var shouldPersistSubModels: Boolean = false | ||
| protected var shouldPersistSubModels: Boolean = if (instance.hasSubModels) true else false | ||
|
|
||
| /** | ||
| * Set option for persist sub models. | ||
| * Extra options for TrainValidationSplitModelWriter, current support "persistSubModels". | ||
| * if sub models exsit, the default value for option "persistSubModels" is "true". | ||
| */ | ||
| @Since("2.3.0") | ||
| def persistSubModels(persist: Boolean): this.type = { | ||
| shouldPersistSubModels = persist | ||
| override def option(key: String, value: String): this.type = { | ||
| key.toLowerCase(Locale.ROOT) match { | ||
| case "persistsubmodels" => shouldPersistSubModels = value.toBoolean | ||
| case _ => throw new IllegalArgumentException( | ||
| s"Illegal option ${key} for TrainValidationSplitModelWriter") | ||
| } | ||
| this | ||
| } | ||
|
|
||
|
|
@@ -334,11 +342,11 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { | |
| val bestModelPath = new Path(path, "bestModel").toString | ||
| instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) | ||
| if (shouldPersistSubModels) { | ||
| require(instance.subModels.isDefined, "Cannot get sub models to persist.") | ||
| require(instance.hasSubModels, "Cannot get sub models to persist.") | ||
| val subModelsPath = new Path(path, "subModels") | ||
| for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) { | ||
| val modelPath = new Path(subModelsPath, paramIndex.toString).toString | ||
| instance.subModels.get(paramIndex).asInstanceOf[MLWritable].save(modelPath) | ||
| instance.subModels(paramIndex).asInstanceOf[MLWritable].save(modelPath) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we clean up/remove the partially-persisted
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @WeichenXu123 Actually I don't think we have to worry about this; Pipeline persistence doesn't clean up if a stage fails to persist (see Pipeline.scala)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, its a good point. But currently model saving code do not have some exception handling code. e.g, overwrite saving, when save failed, it do not recover the old directory.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good question about cleaning up partially saved models. I agree it'd be nice to do in the future, rather than now. |
||
| } | ||
| } | ||
| } | ||
|
|
@@ -370,8 +378,8 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { | |
| Some(_subModels) | ||
| } else None | ||
|
|
||
| val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics, | ||
| subModels) | ||
| val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics) | ||
| .setSubModels(subModels) | ||
| model.set(model.estimator, estimator) | ||
| .set(model.evaluator, evaluator) | ||
| .set(model.estimatorParamMaps, estimatorParamMaps) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -107,6 +107,13 @@ abstract class MLWriter extends BaseReadWrite with Logging { | |
| @Since("1.6.0") | ||
| protected def saveImpl(path: String): Unit | ||
|
|
||
| /** | ||
| * `option()` handles extra options. If subclasses need to support extra options, override this | ||
| * method. | ||
| */ | ||
| @Since("2.3.0") | ||
| def option(key: String, value: String): this.type = this | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rather than overriding this in each subclass, let's have this option() method collect the specified options in a map which is consumed by the subclass when saveImpl() is called. |
||
|
|
||
| /** | ||
| * Overwrites if the output path already exists. | ||
| */ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -199,8 +199,7 @@ class CrossValidatorSuite | |
| .build() | ||
| val eval = new BinaryClassificationEvaluator | ||
| val numFolds = 3 | ||
| val subdirName = Identifiable.randomUID("testSubModels") | ||
| val subPath = new File(tempDir, subdirName) | ||
| val subPath = new File(tempDir, "testCrossValidatorSubModels") | ||
| val persistSubModelsPath = new File(subPath, "subModels").toString | ||
|
||
|
|
||
| val cv = new CrossValidator() | ||
|
|
@@ -213,24 +212,25 @@ class CrossValidatorSuite | |
|
|
||
| val cvModel = cv.fit(dataset) | ||
|
|
||
| assert(cvModel.subModels.isDefined && cvModel.subModels.get.length == numFolds) | ||
| cvModel.subModels.get.foreach(array => assert(array.length == lrParamMaps.length)) | ||
|
|
||
| val savingPathWithoutSubModels = new File(subPath, "cvModel2").getPath | ||
| cvModel.save(savingPathWithoutSubModels) | ||
| val cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels) | ||
| assert(cvModel2.subModels.isEmpty) | ||
| assert(cvModel.hasSubModels && cvModel.subModels.length == numFolds) | ||
| cvModel.subModels.foreach(array => assert(array.length == lrParamMaps.length)) | ||
|
|
||
| // Test the default value for option "persistSubModel" to be "true" | ||
| val savingPathWithSubModels = new File(subPath, "cvModel3").getPath | ||
| cvModel.save(savingPathWithSubModels, persistSubModels = true) | ||
| cvModel.save(savingPathWithSubModels) | ||
| val cvModel3 = CrossValidatorModel.load(savingPathWithSubModels) | ||
| assert(cvModel3.subModels.isDefined && cvModel3.subModels.get.length == numFolds) | ||
| cvModel3.subModels.get.foreach(array => assert(array.length == lrParamMaps.length)) | ||
| assert(cvModel3.hasSubModels && cvModel3.subModels.length == numFolds) | ||
| cvModel3.subModels.foreach(array => assert(array.length == lrParamMaps.length)) | ||
|
|
||
| val savingPathWithoutSubModels = new File(subPath, "cvModel2").getPath | ||
| cvModel.write.option("persistSubModels", "false").save(savingPathWithoutSubModels) | ||
| val cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels) | ||
| assert(!cvModel2.hasSubModels) | ||
|
|
||
| for (i <- 0 until numFolds) { | ||
| for (j <- 0 until lrParamMaps.length) { | ||
| assert(cvModel.subModels.get(i)(j).asInstanceOf[LogisticRegressionModel].uid === | ||
| cvModel3.subModels.get(i)(j).asInstanceOf[LogisticRegressionModel].uid) | ||
| assert(cvModel.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid === | ||
| cvModel3.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid) | ||
| } | ||
| } | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Only use Since annotations for public APIs