Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,10 @@ private[shared] object SharedParamsCodeGen {
ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false),
ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"),
isValid = "ParamValidators.gtEq(2)", isExpertParam = true),
ParamDesc[Boolean]("collectSubModels", "whether to collect a list of sub-models trained " +
"during tuning",
ParamDesc[Boolean]("collectSubModels", "If set to false, then only the single best " +
"sub-model will be available after fitting. If set to true, then all sub-models will be " +
"available. Warning: For large models, collecting all sub-models can cause OOMs on the " +
"Spark driver.",
Some("false"), isExpertParam = true)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@

package org.apache.spark.ml.tuning

import java.io.IOException
import java.util.{List => JList}
import java.util.{List => JList, Locale}

import scala.collection.JavaConverters._
import scala.concurrent.Future
Expand Down Expand Up @@ -176,7 +175,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
logInfo(s"Best cross-validation metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
instr.logSuccess(bestModel)
copyValues(new CrossValidatorModel(uid, bestModel, metrics, subModels).setParent(this))
copyValues(new CrossValidatorModel(uid, bestModel, metrics)
.setSubModels(subModels).setParent(this))
}

@Since("1.4.0")
Expand Down Expand Up @@ -252,19 +252,29 @@ object CrossValidator extends MLReadable[CrossValidator] {
class CrossValidatorModel private[ml] (
@Since("1.4.0") override val uid: String,
@Since("1.2.0") val bestModel: Model[_],
@Since("1.5.0") val avgMetrics: Array[Double],
@Since("2.3.0") val subModels: Option[Array[Array[Model[_]]]])
@Since("1.5.0") val avgMetrics: Array[Double])
extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {

/** A Python-friendly auxiliary constructor. */
private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: JList[Double]) = {
this(uid, bestModel, avgMetrics.asScala.toArray, null)
this(uid, bestModel, avgMetrics.asScala.toArray)
}

private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: Array[Double]) = {
this(uid, bestModel, avgMetrics, null)
private var _subModels: Option[Array[Array[Model[_]]]] = None

@Since("2.3.0")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only use Since annotations for public APIs

/**
 * Attaches the given sub-models to this model and returns this instance so that the call
 * can be chained. Only callable from within the `tuning` package.
 *
 * @param subModels the sub-models to attach, or `None` when no sub-models are available
 * @return this model
 */
private[tuning] def setSubModels(
    subModels: Option[Array[Array[Model[_]]]]): CrossValidatorModel = {
  this._subModels = subModels
  this
}

/**
 * Sub-models attached to this model, as a two-dimensional array: the index of the outer
 * array is the fold index, and the index of the inner array corresponds to the ordering of
 * `estimatorParamMaps`.
 *
 * @return the sub-models held by this model
 * @throws NoSuchElementException if sub-models are not available. To make them available,
 *         set the `collectSubModels` Param to true before fitting.
 */
@Since("2.3.0")
def subModels: Array[Array[Model[_]]] = _subModels.getOrElse {
  // Same exception type as a bare Option.get, but with a message telling users how to fix it.
  throw new NoSuchElementException(
    "subModels not available. To retrieve subModels, make sure to set collectSubModels " +
      "to true before fitting.")
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add Scala doc. We'll need to explain what the inner and outer array are and which one corresponds to the ordering of estimatorParamsMaps.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, can you please add a better Exception message? If submodels are not available, then we should tell users to set the collectSubModels Param before fitting.


/**
 * Whether sub-models have been attached to this model.
 *
 * @return true if sub-models are present, false otherwise
 */
@Since("2.3.0")
def hasSubModels: Boolean = {
  _subModels.nonEmpty
}

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
Expand All @@ -281,20 +291,13 @@ class CrossValidatorModel private[ml] (
val copied = new CrossValidatorModel(
uid,
bestModel.copy(extra).asInstanceOf[Model[_]],
avgMetrics.clone(),
CrossValidatorModel.copySubModels(subModels))
avgMetrics.clone()
).setSubModels(CrossValidatorModel.copySubModels(_subModels))
copyValues(copied, extra).setParent(parent)
}

/**
 * Creates a writer for persisting this model.
 *
 * @return a new `CrossValidatorModelWriter` wrapping this instance
 */
@Since("1.6.0")
override def write: MLWriter = {
  new CrossValidatorModel.CrossValidatorModelWriter(this)
}

@Since("2.3.0")
@throws[IOException]("If the input path already exists but overwrite is not enabled.")
def save(path: String, persistSubModels: Boolean): Unit = {
write.asInstanceOf[CrossValidatorModel.CrossValidatorModelWriter]
.persistSubModels(persistSubModels).save(path)
}
}

@Since("1.6.0")
Expand Down Expand Up @@ -325,14 +328,19 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {

ValidatorParams.validateParams(instance)

protected var shouldPersistSubModels: Boolean = false
// Whether sub-models are written on save; defaults to true exactly when the instance has them.
protected var shouldPersistSubModels: Boolean = instance.hasSubModels

/**
* Set option for persist sub models.
* Extra options for CrossValidatorModelWriter; currently supports "persistSubModels".
* If sub-models exist, the default value for option "persistSubModels" is "true".
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: exsit -> exist

*/
@Since("2.3.0")
def persistSubModels(persist: Boolean): this.type = {
shouldPersistSubModels = persist
override def option(key: String, value: String): this.type = {
key.toLowerCase(Locale.ROOT) match {
case "persistsubmodels" => shouldPersistSubModels = value.toBoolean
case _ => throw new IllegalArgumentException(
s"Illegal option ${key} for CrossValidatorModelWriter")
}
this
}

Expand All @@ -344,13 +352,13 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
val bestModelPath = new Path(path, "bestModel").toString
instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
if (shouldPersistSubModels) {
require(instance.subModels.isDefined, "Cannot get sub models to persist.")
require(instance.hasSubModels, "Cannot get sub models to persist.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This error message may be unclear. How about adding: "When persisting tuning models, you can only set persistSubModels to true if the tuning was done with collectSubModels set to true. To save the sub-models, try rerunning fitting with collectSubModels set to true."

val subModelsPath = new Path(path, "subModels")
for (splitIndex <- 0 until instance.getNumFolds) {
val splitPath = new Path(subModelsPath, splitIndex.toString)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about naming this with the string "fold":
splitIndex.toString --> "fold" + splitIndex.toString?

for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) {
val modelPath = new Path(splitPath, paramIndex.toString).toString
instance.subModels.get(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath)
instance.subModels(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath)
}
}
}
Expand Down Expand Up @@ -388,7 +396,8 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
Some(_subModels)
} else None

val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics, subModels)
val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics)
.setSubModels(subModels)
model.set(model.estimator, estimator)
.set(model.evaluator, evaluator)
.set(model.estimatorParamMaps, estimatorParamMaps)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@

package org.apache.spark.ml.tuning

import java.io.IOException
import java.util.{List => JList}
import java.util.{List => JList, Locale}

import scala.collection.JavaConverters._
import scala.concurrent.Future
Expand Down Expand Up @@ -173,7 +172,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
logInfo(s"Best train validation split metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
instr.logSuccess(bestModel)
copyValues(new TrainValidationSplitModel(uid, bestModel, metrics, subModels).setParent(this))
copyValues(new TrainValidationSplitModel(uid, bestModel, metrics)
.setSubModels(subModels).setParent(this))
}

@Since("1.5.0")
Expand Down Expand Up @@ -245,19 +245,29 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] {
class TrainValidationSplitModel private[ml] (
@Since("1.5.0") override val uid: String,
@Since("1.5.0") val bestModel: Model[_],
@Since("1.5.0") val validationMetrics: Array[Double],
@Since("2.3.0") val subModels: Option[Array[Model[_]]])
@Since("1.5.0") val validationMetrics: Array[Double])
extends Model[TrainValidationSplitModel] with TrainValidationSplitParams with MLWritable {

/** A Python-friendly auxiliary constructor. */
private[ml] def this(uid: String, bestModel: Model[_], validationMetrics: JList[Double]) = {
this(uid, bestModel, validationMetrics.asScala.toArray, null)
this(uid, bestModel, validationMetrics.asScala.toArray)
}

private[ml] def this(uid: String, bestModel: Model[_], validationMetrics: Array[Double]) = {
this(uid, bestModel, validationMetrics, null)
private var _subModels: Option[Array[Model[_]]] = None

/**
 * Attaches the given sub-models to this model and returns this instance so that the call
 * can be chained. Only callable from within the `tuning` package.
 *
 * @param subModels the sub-models to attach, or `None` when no sub-models are available
 * @return this model
 */
@Since("2.3.0")
private[tuning] def setSubModels(
    subModels: Option[Array[Model[_]]]): TrainValidationSplitModel = {
  this._subModels = subModels
  this
}

/**
 * Sub-models attached to this model, as an array whose index corresponds to the ordering of
 * `estimatorParamMaps`.
 *
 * @return the sub-models held by this model
 * @throws NoSuchElementException if sub-models are not available. To make them available,
 *         set the `collectSubModels` Param to true before fitting.
 */
@Since("2.3.0")
def subModels: Array[Model[_]] = _subModels.getOrElse {
  // Same exception type as a bare Option.get, but with a message telling users how to fix it.
  throw new NoSuchElementException(
    "subModels not available. To retrieve subModels, make sure to set collectSubModels " +
      "to true before fitting.")
}

/**
 * Whether sub-models have been attached to this model.
 *
 * @return true if sub-models are present, false otherwise
 */
@Since("2.3.0")
def hasSubModels: Boolean = {
  _subModels.nonEmpty
}

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
Expand All @@ -274,20 +284,13 @@ class TrainValidationSplitModel private[ml] (
val copied = new TrainValidationSplitModel (
uid,
bestModel.copy(extra).asInstanceOf[Model[_]],
validationMetrics.clone(),
TrainValidationSplitModel.copySubModels(subModels))
validationMetrics.clone()
).setSubModels(TrainValidationSplitModel.copySubModels(_subModels))
copyValues(copied, extra).setParent(parent)
}

/**
 * Creates a writer for persisting this model.
 *
 * @return a new `TrainValidationSplitModelWriter` wrapping this instance
 */
@Since("2.0.0")
override def write: MLWriter = {
  new TrainValidationSplitModel.TrainValidationSplitModelWriter(this)
}

@Since("2.3.0")
@throws[IOException]("If the input path already exists but overwrite is not enabled.")
def save(path: String, persistSubModels: Boolean): Unit = {
write.asInstanceOf[TrainValidationSplitModel.TrainValidationSplitModelWriter]
.persistSubModels(persistSubModels).save(path)
}
}

@Since("2.0.0")
Expand Down Expand Up @@ -315,14 +318,19 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {

ValidatorParams.validateParams(instance)

protected var shouldPersistSubModels: Boolean = false
// Whether sub-models are written on save; defaults to true exactly when the instance has them.
protected var shouldPersistSubModels: Boolean = instance.hasSubModels

/**
* Set option for persist sub models.
* Extra options for TrainValidationSplitModelWriter; currently supports "persistSubModels".
* If sub-models exist, the default value for option "persistSubModels" is "true".
*/
@Since("2.3.0")
def persistSubModels(persist: Boolean): this.type = {
shouldPersistSubModels = persist
override def option(key: String, value: String): this.type = {
key.toLowerCase(Locale.ROOT) match {
case "persistsubmodels" => shouldPersistSubModels = value.toBoolean
case _ => throw new IllegalArgumentException(
s"Illegal option ${key} for TrainValidationSplitModelWriter")
}
this
}

Expand All @@ -334,11 +342,11 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
val bestModelPath = new Path(path, "bestModel").toString
instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
if (shouldPersistSubModels) {
require(instance.subModels.isDefined, "Cannot get sub models to persist.")
require(instance.hasSubModels, "Cannot get sub models to persist.")
val subModelsPath = new Path(path, "subModels")
for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) {
val modelPath = new Path(subModelsPath, paramIndex.toString).toString
instance.subModels.get(paramIndex).asInstanceOf[MLWritable].save(modelPath)
instance.subModels(paramIndex).asInstanceOf[MLWritable].save(modelPath)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we clean up/remove the partially-persisted subModels if any of these save() calls fail? E.g. let's say we have four subModels and the first three save() calls succeed but the fourth fails - should we delete the folders for the first three submodels?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@WeichenXu123 Actually I don't think we have to worry about this; Pipeline persistence doesn't clean up if a stage fails to persist (see Pipeline.scala)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, that's a good point. But currently the model-saving code does not have any exception-handling code. E.g., for overwrite saving, when the save fails, it does not recover the old directory.
I think these things can be done in separate PRs.
cc @jkbradley What's your opinion?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question about cleaning up partially saved models. I agree it'd be nice to do in the future, rather than now.

}
}
}
Expand Down Expand Up @@ -370,8 +378,8 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
Some(_subModels)
} else None

val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics,
subModels)
val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics)
.setSubModels(subModels)
model.set(model.estimator, estimator)
.set(model.evaluator, evaluator)
.set(model.estimatorParamMaps, estimatorParamMaps)
Expand Down
7 changes: 7 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,13 @@ abstract class MLWriter extends BaseReadWrite with Logging {
@Since("1.6.0")
protected def saveImpl(path: String): Unit

/**
 * Hook for writer-specific extra options. This default implementation ignores both arguments
 * and returns this writer unchanged; subclasses that need to support extra options should
 * override this method.
 *
 * @param key name of the option
 * @param value value of the option, given as a string
 * @return this writer, so that calls can be chained
 */
@Since("2.3.0")
def option(key: String, value: String): this.type = {
  this
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than overriding this in each subclass, let's have this option() method collect the specified options in a map which is consumed by the subclass when saveImpl() is called.


/**
* Overwrites if the output path already exists.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,7 @@ class CrossValidatorSuite
.build()
val eval = new BinaryClassificationEvaluator
val numFolds = 3
val subdirName = Identifiable.randomUID("testSubModels")
val subPath = new File(tempDir, subdirName)
val subPath = new File(tempDir, "testCrossValidatorSubModels")
val persistSubModelsPath = new File(subPath, "subModels").toString
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not used


val cv = new CrossValidator()
Expand All @@ -213,24 +212,25 @@ class CrossValidatorSuite

val cvModel = cv.fit(dataset)

assert(cvModel.subModels.isDefined && cvModel.subModels.get.length == numFolds)
cvModel.subModels.get.foreach(array => assert(array.length == lrParamMaps.length))

val savingPathWithoutSubModels = new File(subPath, "cvModel2").getPath
cvModel.save(savingPathWithoutSubModels)
val cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels)
assert(cvModel2.subModels.isEmpty)
assert(cvModel.hasSubModels && cvModel.subModels.length == numFolds)
cvModel.subModels.foreach(array => assert(array.length == lrParamMaps.length))

// Test the default value for option "persistSubModels" to be "true"
val savingPathWithSubModels = new File(subPath, "cvModel3").getPath
cvModel.save(savingPathWithSubModels, persistSubModels = true)
cvModel.save(savingPathWithSubModels)
val cvModel3 = CrossValidatorModel.load(savingPathWithSubModels)
assert(cvModel3.subModels.isDefined && cvModel3.subModels.get.length == numFolds)
cvModel3.subModels.get.foreach(array => assert(array.length == lrParamMaps.length))
assert(cvModel3.hasSubModels && cvModel3.subModels.length == numFolds)
cvModel3.subModels.foreach(array => assert(array.length == lrParamMaps.length))

val savingPathWithoutSubModels = new File(subPath, "cvModel2").getPath
cvModel.write.option("persistSubModels", "false").save(savingPathWithoutSubModels)
val cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels)
assert(!cvModel2.hasSubModels)

for (i <- 0 until numFolds) {
for (j <- 0 until lrParamMaps.length) {
assert(cvModel.subModels.get(i)(j).asInstanceOf[LogisticRegressionModel].uid ===
cvModel3.subModels.get(i)(j).asInstanceOf[LogisticRegressionModel].uid)
assert(cvModel.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid ===
cvModel3.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,7 @@ class TrainValidationSplitSuite
.addGrid(lr.maxIter, Array(0, 3))
.build()
val eval = new BinaryClassificationEvaluator
val subdirName = Identifiable.randomUID("testSubModels")
val subPath = new File(tempDir, subdirName)
val subPath = new File(tempDir, "testTrainValidationSplitSubModels")

val tvs = new TrainValidationSplit()
.setEstimator(lr)
Expand All @@ -204,21 +203,22 @@ class TrainValidationSplitSuite

val tvsModel = tvs.fit(dataset)

assert(tvsModel.subModels.isDefined && tvsModel.subModels.get.length == lrParamMaps.length)

val savingPathWithoutSubModels = new File(subPath, "tvsModel2").getPath
tvsModel.save(savingPathWithoutSubModels)
val tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels)
assert(tvsModel2.subModels.isEmpty)
assert(tvsModel.hasSubModels && tvsModel.subModels.length == lrParamMaps.length)

// Test the default value for option "persistSubModels" to be "true"
val savingPathWithSubModels = new File(subPath, "tvsModel3").getPath
tvsModel.save(savingPathWithSubModels, persistSubModels = true)
tvsModel.save(savingPathWithSubModels)
val tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels)
assert(tvsModel3.subModels.isDefined && tvsModel3.subModels.get.length == lrParamMaps.length)
assert(tvsModel3.hasSubModels && tvsModel3.subModels.length == lrParamMaps.length)

val savingPathWithoutSubModels = new File(subPath, "tvsModel2").getPath
tvsModel.write.option("persistSubModels", "false").save(savingPathWithoutSubModels)
val tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels)
assert(!tvsModel2.hasSubModels)

for (i <- 0 until lrParamMaps.length) {
assert(tvsModel.subModels.get(i).asInstanceOf[LogisticRegressionModel].uid ===
tvsModel3.subModels.get(i).asInstanceOf[LogisticRegressionModel].uid)
assert(tvsModel.subModels(i).asInstanceOf[LogisticRegressionModel].uid ===
tvsModel3.subModels(i).asInstanceOf[LogisticRegressionModel].uid)
}
}

Expand Down