From 46d3ab3899c196311368b3383338b3d4e6d5aeaa Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 11 Sep 2017 21:28:53 +0800 Subject: [PATCH 01/10] init pr --- .../ml/param/shared/SharedParamsCodeGen.scala | 7 +- .../spark/ml/param/shared/sharedParams.scala | 34 +++++ .../spark/ml/tuning/CrossValidator.scala | 129 +++++++++++++++--- .../ml/tuning/TrainValidationSplit.scala | 120 +++++++++++++--- .../spark/ml/tuning/ValidatorParams.scala | 22 ++- .../org/apache/spark/ml/util/ReadWrite.scala | 11 +- .../spark/ml/tuning/CrossValidatorSuite.scala | 69 +++++++++- .../ml/tuning/TrainValidationSplitSuite.scala | 60 +++++++- 8 files changed, 391 insertions(+), 61 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 1860fe836174..a6b9fca0a6d2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -82,7 +82,12 @@ private[shared] object SharedParamsCodeGen { "all instance weights as 1.0"), ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false), ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"), - isValid = "ParamValidators.gtEq(2)", isExpertParam = true)) + isValid = "ParamValidators.gtEq(2)", isExpertParam = true), + ParamDesc[Boolean]("collectSubModels", "whether to collect sub models when tuning fitting", + Some("false"), isExpertParam = true), + ParamDesc[String]("persistSubModelsPath", "The path to persist sub models when " + + "tuning fitting", Some("\"\""), isExpertParam = true) + ) val code = genSharedParams(params) val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 6061d9ca0a08..fc481c70c219 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -402,4 +402,38 @@ private[ml] trait HasAggregationDepth extends Params { /** @group expertGetParam */ final def getAggregationDepth: Int = $(aggregationDepth) } + +/** + * Trait for shared param collectSubModels (default: false). + */ +private[ml] trait HasCollectSubModels extends Params { + + /** + * Param for whether to collect sub models when tuning fitting. + * @group expertParam + */ + final val collectSubModels: BooleanParam = new BooleanParam(this, "collectSubModels", "whether to collect sub models when tuning fitting") + + setDefault(collectSubModels, false) + + /** @group expertGetParam */ + final def getCollectSubModels: Boolean = $(collectSubModels) +} + +/** + * Trait for shared param persistSubModelsPath (default: ""). + */ +private[ml] trait HasPersistSubModelsPath extends Params { + + /** + * Param for The path to persist sub models when tuning fitting. + * @group expertParam + */ + final val persistSubModelsPath: Param[String] = new Param[String](this, "persistSubModelsPath", "The path to persist sub models when tuning fitting") + + setDefault(persistSubModelsPath, "") + + /** @group expertGetParam */ + final def getPersistSubModelsPath: String = $(persistSubModelsPath) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index ce2a3a2e4041..da60362a4daa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.tuning +import java.io.IOException import java.util.{List => JList} import scala.collection.JavaConverters._ @@ -31,7 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} -import org.apache.spark.ml.param.shared.HasParallelism +import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism, HasPersistSubModelsPath} import org.apache.spark.ml.util._ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.{DataFrame, Dataset} @@ -67,7 +68,8 @@ private[ml] trait CrossValidatorParams extends ValidatorParams { @Since("1.2.0") class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) extends Estimator[CrossValidatorModel] - with CrossValidatorParams with HasParallelism with MLWritable with Logging { + with CrossValidatorParams with HasParallelism with HasCollectSubModels + with HasPersistSubModelsPath with MLWritable with Logging { @Since("1.2.0") def this() = this(Identifiable.randomUID("cv")) @@ -101,6 +103,14 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) @Since("2.3.0") def setParallelism(value: Int): this.type = set(parallelism, value) + /** @group expertSetParam */ + @Since("2.3.0") + def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) + + /** @group expertSetParam */ + @Since("2.3.0") + def setPersistSubModelsPath(value: String): this.type = set(persistSubModelsPath, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): CrossValidatorModel = { val schema = dataset.schema @@ -117,6 +127,13 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) instr.logParams(numFolds, seed, parallelism) logTuningParams(instr) + val collectSubModelsParam = $(collectSubModels) + val persistSubModelsPathParam = $(persistSubModelsPath) + + var subModels: Array[Array[Model[_]]] = if (collectSubModelsParam) { + Array.fill($(numFolds))(Array.fill[Model[_]](epm.length)(null)) + } else null + // Compute metrics for each model over each split val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) => @@ -125,10 +142,19 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) logDebug(s"Train split $splitIndex with multiple sets of parameters.") // Fit models in a Future for training in parallel - val modelFutures = epm.map { paramMap => + val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => Future[Model[_]] { - val model = est.fit(trainingDataset, paramMap) - model.asInstanceOf[Model[_]] + val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] + + if (collectSubModelsParam) { + subModels(splitIndex)(paramIndex) = model + } + if (persistSubModelsPathParam.nonEmpty) { + val modelPath = new Path(new Path(persistSubModelsPathParam, splitIndex.toString), + paramIndex.toString).toString + model.asInstanceOf[MLWritable].save(modelPath) + } + model } (executionContext) } @@ -160,7 +186,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] instr.logSuccess(bestModel) - copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this)) + copyValues(new CrossValidatorModel(uid, bestModel, metrics, subModels).setParent(this)) } @Since("1.4.0") @@ -212,14 +238,12 @@ object CrossValidator extends MLReadable[CrossValidator] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, className) - val numFolds = (metadata.params \ "numFolds").extract[Int] - val seed = (metadata.params \ "seed").extract[Long] - new CrossValidator(metadata.uid) + val cv = new CrossValidator(metadata.uid) .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) - .setNumFolds(numFolds) - .setSeed(seed) + DefaultParamsReader.getAndSetParams(cv, metadata, skipParams = List("estimatorParamMaps")) + cv } } } @@ -237,12 +261,17 @@ object CrossValidator extends MLReadable[CrossValidator] { class CrossValidatorModel private[ml] ( @Since("1.4.0") override val uid: String, @Since("1.2.0") val bestModel: Model[_], - @Since("1.5.0") val avgMetrics: Array[Double]) + @Since("1.5.0") val avgMetrics: Array[Double], + @Since("2.3.0") val subModels: Array[Array[Model[_]]]) extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable { /** A Python-friendly auxiliary constructor. */ private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: JList[Double]) = { - this(uid, bestModel, avgMetrics.asScala.toArray) + this(uid, bestModel, avgMetrics.asScala.toArray, null) + } + + private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: Array[Double]) = { + this(uid, bestModel, avgMetrics, null) } @Since("2.0.0") @@ -261,17 +290,40 @@ class CrossValidatorModel private[ml] ( val copied = new CrossValidatorModel( uid, bestModel.copy(extra).asInstanceOf[Model[_]], - avgMetrics.clone()) + avgMetrics.clone(), + CrossValidatorModel.copySubModels(subModels)) copyValues(copied, extra).setParent(parent) } @Since("1.6.0") override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this) + + @Since("2.3.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + def save(path: String, persistSubModels: Boolean): Unit = { + write.asInstanceOf[CrossValidatorModel.CrossValidatorModelWriter] + .persistSubModels(persistSubModels).save(path) + } } @Since("1.6.0") object CrossValidatorModel extends MLReadable[CrossValidatorModel] { + private[CrossValidatorModel] def copySubModels(subModels: Array[Array[Model[_]]]) = { + var copiedSubModels: Array[Array[Model[_]]] = null + if (subModels != null) { + val numFolds = subModels.length + val numParamMaps = subModels(0).length + copiedSubModels = Array.fill(numFolds)(Array.fill[Model[_]](numParamMaps)(null)) + for (i <- 0 until numFolds) { + for (j <- 0 until numParamMaps) { + copiedSubModels(i)(j) = subModels(i)(j).copy(ParamMap.empty).asInstanceOf[Model[_]] + } + } + } + copiedSubModels + } + @Since("1.6.0") override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader @@ -283,12 +335,35 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { ValidatorParams.validateParams(instance) + protected var shouldPersistSubModels: Boolean = false + + /** + * Set option for persist sub models. + */ + @Since("2.3.0") + def persistSubModels(persist: Boolean): this.type = { + shouldPersistSubModels = persist + this + } + override protected def saveImpl(path: String): Unit = { import org.json4s.JsonDSL._ - val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq + val extraMetadata = ("avgMetrics" -> instance.avgMetrics.toSeq) ~ + ("shouldPersistSubModels" -> shouldPersistSubModels) ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata)) val bestModelPath = new Path(path, "bestModel").toString instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) + if (shouldPersistSubModels) { + require(instance.subModels != null, "Cannot get sub models to persist.") + val subModelsPath = new Path(path, "subModels") + for (splitIndex <- 0 until instance.getNumFolds) { + val splitPath = new Path(subModelsPath, splitIndex.toString) + for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) { + val modelPath = new Path(splitPath, paramIndex.toString).toString + instance.subModels(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath) + } + } + } } } @@ -303,16 +378,32 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, className) val numFolds = (metadata.params \ "numFolds").extract[Int] - val seed = (metadata.params \ "seed").extract[Long] val bestModelPath = new Path(path, "bestModel").toString val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray - val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics) + val shouldPersistSubModels = (metadata.metadata \ "shouldPersistSubModels").extract[Boolean] + + val subModels: Array[Array[Model[_]]] = if (shouldPersistSubModels) { + val subModelsPath = new Path(path, "subModels") + val _subModels = Array.fill(numFolds)(Array.fill[Model[_]]( + estimatorParamMaps.length)(null)) + for (splitIndex <- 0 until numFolds) { + val splitPath = new Path(subModelsPath, splitIndex.toString) + for (paramIndex <- 0 until estimatorParamMaps.length) { + val modelPath = new Path(splitPath, paramIndex.toString).toString + _subModels(splitIndex)(paramIndex) = + DefaultParamsReader.loadParamsInstance(modelPath, sc) + } + } + _subModels + } else null + + val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics, subModels) model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) - .set(model.numFolds, numFolds) - .set(model.seed, seed) + DefaultParamsReader.getAndSetParams(model, metadata, skipParams = List("estimatorParamMaps")) + model } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 16db0f5f12c7..35c3df6e751e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.tuning +import java.io.IOException import java.util.{List => JList} import scala.collection.JavaConverters._ @@ -32,7 +33,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} -import org.apache.spark.ml.param.shared.HasParallelism +import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism, HasPersistSubModelsPath} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType @@ -66,7 +67,8 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams { @Since("1.5.0") class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String) extends Estimator[TrainValidationSplitModel] - with TrainValidationSplitParams with HasParallelism with MLWritable with Logging { + with TrainValidationSplitParams with HasParallelism with HasCollectSubModels + with HasPersistSubModelsPath with MLWritable with Logging { @Since("1.5.0") def this() = this(Identifiable.randomUID("tvs")) @@ -100,6 +102,14 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St @Since("2.3.0") def setParallelism(value: Int): this.type = set(parallelism, value) + /** @group expertSetParam */ + @Since("2.3.0") + def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) + + /** @group expertSetParam */ + @Since("2.3.0") + def setPersistSubModelsPath(value: String): this.type = set(persistSubModelsPath, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): TrainValidationSplitModel = { val schema = dataset.schema @@ -120,12 +130,27 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St trainingDataset.cache() validationDataset.cache() + val collectSubModelsParam = $(collectSubModels) + val persistSubModelsPathParam = $(persistSubModelsPath) + + var subModels: Array[Model[_]] = if (collectSubModelsParam) { + Array.fill[Model[_]](epm.length)(null) + } else null + // Fit models in a Future for training in parallel logDebug(s"Train split with multiple sets of parameters.") - val modelFutures = epm.map { paramMap => + val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => Future[Model[_]] { - val model = est.fit(trainingDataset, paramMap) - model.asInstanceOf[Model[_]] + val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] + + if (collectSubModelsParam) { + subModels(paramIndex) = model + } + if (persistSubModelsPathParam.nonEmpty) { + val modelPath = new Path(persistSubModelsPathParam, paramIndex.toString).toString + model.asInstanceOf[MLWritable].save(modelPath) + } + model } (executionContext) } @@ -157,7 +182,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St logInfo(s"Best train validation split metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] instr.logSuccess(bestModel) - copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this)) + copyValues(new TrainValidationSplitModel(uid, bestModel, metrics, subModels).setParent(this)) } @Since("1.5.0") @@ -207,14 +232,12 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, className) - val trainRatio = (metadata.params \ "trainRatio").extract[Double] - val seed = (metadata.params \ "seed").extract[Long] - new TrainValidationSplit(metadata.uid) + val tvs = new TrainValidationSplit(metadata.uid) .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) - .setTrainRatio(trainRatio) - .setSeed(seed) + DefaultParamsReader.getAndSetParams(tvs, metadata, skipParams = List("estimatorParamMaps")) + tvs } } } @@ -230,12 +253,17 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] { class TrainValidationSplitModel private[ml] ( @Since("1.5.0") override val uid: String, @Since("1.5.0") val bestModel: Model[_], - @Since("1.5.0") val validationMetrics: Array[Double]) + @Since("1.5.0") val validationMetrics: Array[Double], + @Since("2.3.0") val subModels: Array[Model[_]]) extends Model[TrainValidationSplitModel] with TrainValidationSplitParams with MLWritable { /** A Python-friendly auxiliary constructor. */ private[ml] def this(uid: String, bestModel: Model[_], validationMetrics: JList[Double]) = { - this(uid, bestModel, validationMetrics.asScala.toArray) + this(uid, bestModel, validationMetrics.asScala.toArray, null) + } + + private[ml] def this(uid: String, bestModel: Model[_], validationMetrics: Array[Double]) = { + this(uid, bestModel, validationMetrics, null) } @Since("2.0.0") @@ -254,17 +282,37 @@ class TrainValidationSplitModel private[ml] ( val copied = new TrainValidationSplitModel ( uid, bestModel.copy(extra).asInstanceOf[Model[_]], - validationMetrics.clone()) + validationMetrics.clone(), + TrainValidationSplitModel.copySubModels(subModels)) copyValues(copied, extra).setParent(parent) } @Since("2.0.0") override def write: MLWriter = new TrainValidationSplitModel.TrainValidationSplitModelWriter(this) + + @Since("2.3.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + def save(path: String, persistSubModels: Boolean): Unit = { + write.asInstanceOf[TrainValidationSplitModel.TrainValidationSplitModelWriter] + .persistSubModels(persistSubModels).save(path) + } } @Since("2.0.0") object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { + private[TrainValidationSplitModel] def copySubModels(subModels: Array[Model[_]]) = { + var copiedSubModels: Array[Model[_]] = null + if (subModels != null) { + val numParamMaps = subModels.length + copiedSubModels = Array.fill[Model[_]](numParamMaps)(null) + for (i <- 0 until numParamMaps) { + copiedSubModels(i) = subModels(i).copy(ParamMap.empty).asInstanceOf[Model[_]] + } + } + copiedSubModels + } + @Since("2.0.0") override def read: MLReader[TrainValidationSplitModel] = new TrainValidationSplitModelReader @@ -276,12 +324,32 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { ValidatorParams.validateParams(instance) + protected var shouldPersistSubModels: Boolean = false + + /** + * Set option for persist sub models. + */ + @Since("2.3.0") + def persistSubModels(persist: Boolean): this.type = { + shouldPersistSubModels = persist + this + } + override protected def saveImpl(path: String): Unit = { import org.json4s.JsonDSL._ - val extraMetadata = "validationMetrics" -> instance.validationMetrics.toSeq + val extraMetadata = ("validationMetrics" -> instance.validationMetrics.toSeq) ~ + ("shouldPersistSubModels" -> shouldPersistSubModels) ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata)) val bestModelPath = new Path(path, "bestModel").toString instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) + if (shouldPersistSubModels) { + require(instance.subModels != null, "Cannot get sub models to persist.") + val subModelsPath = new Path(path, "subModels") + for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) { + val modelPath = new Path(subModelsPath, paramIndex.toString).toString + instance.subModels(paramIndex).asInstanceOf[MLWritable].save(modelPath) + } + } } } @@ -295,17 +363,29 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, className) - val trainRatio = (metadata.params \ "trainRatio").extract[Double] - val seed = (metadata.params \ "seed").extract[Long] val bestModelPath = new Path(path, "bestModel").toString val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray - val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics) + val shouldPersistSubModels = (metadata.metadata \ "shouldPersistSubModels").extract[Boolean] + + val subModels: Array[Model[_]] = if (shouldPersistSubModels) { + val subModelsPath = new Path(path, "subModels") + val _subModels = Array.fill[Model[_]](estimatorParamMaps.length)(null) + for (paramIndex <- 0 until estimatorParamMaps.length) { + val modelPath = new Path(subModelsPath, paramIndex.toString).toString + _subModels(paramIndex) = + DefaultParamsReader.loadParamsInstance(modelPath, sc) + } + _subModels + } else null + + val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics, + subModels) model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) - .set(model.trainRatio, trainRatio) - .set(model.seed, seed) + DefaultParamsReader.getAndSetParams(model, metadata, skipParams = List("estimatorParamMaps")) + model } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala index 0ab6eed95938..363304ef1014 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -150,20 +150,14 @@ private[ml] object ValidatorParams { }.toSeq )) - val validatorSpecificParams = instance match { - case cv: CrossValidatorParams => - List("numFolds" -> parse(cv.numFolds.jsonEncode(cv.getNumFolds))) - case tvs: TrainValidationSplitParams => - List("trainRatio" -> parse(tvs.trainRatio.jsonEncode(tvs.getTrainRatio))) - case _ => - // This should not happen. - throw new NotImplementedError("ValidatorParams.saveImpl does not handle type: " + - instance.getClass.getCanonicalName) - } - - val jsonParams = validatorSpecificParams ++ List( - "estimatorParamMaps" -> parse(estimatorParamMapsJson), - "seed" -> parse(instance.seed.jsonEncode(instance.getSeed))) + val params = instance.extractParamMap().toSeq + val skipParams = List("estimator", "evaluator", "estimatorParamMaps") + val jsonParams = render(params + .filter { case ParamPair(p, v) => !skipParams.contains(p.name)} + .map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList ++ List("estimatorParamMaps" -> parse(estimatorParamMapsJson)) + ) DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 65f142cfbbcb..dcbba18ec26d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -399,14 +399,17 @@ private[ml] object DefaultParamsReader { * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]]. * TODO: Move to [[Metadata]] method */ - def getAndSetParams(instance: Params, metadata: Metadata): Unit = { + def getAndSetParams(instance: Params, metadata: Metadata, + skipParams: List[String] = null): Unit = { implicit val format = DefaultFormats metadata.params match { case JObject(pairs) => pairs.foreach { case (paramName, jsonValue) => - val param = instance.getParam(paramName) - val value = param.jsonDecode(compact(render(jsonValue))) - instance.set(param, value) + if (skipParams == null || !skipParams.contains(paramName)) { + val param = instance.getParam(paramName) + val value = param.jsonDecode(compact(render(jsonValue))) + instance.set(param, value) + } } case _ => throw new IllegalArgumentException( diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index a8d4377cff2d..57d26f3f67d1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -17,7 +17,10 @@ package org.apache.spark.ml.tuning +import java.io.File + import org.apache.spark.SparkFunSuite + import org.apache.spark.ml.{Estimator, Model, Pipeline} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest} import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput @@ -27,7 +30,7 @@ import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType @@ -159,12 +162,19 @@ class CrossValidatorSuite .setEvaluator(evaluator) .setNumFolds(20) .setEstimatorParamMaps(paramMaps) + .setSeed(42L) + .setParallelism(2) + .setCollectSubModels(true) + .setPersistSubModelsPath("cvSubModels") val cv2 = testDefaultReadWrite(cv, testParams = false) assert(cv.uid === cv2.uid) assert(cv.getNumFolds === cv2.getNumFolds) assert(cv.getSeed === cv2.getSeed) + assert(cv.getParallelism === cv2.getParallelism) + assert(cv.getCollectSubModels === cv2.getCollectSubModels) + assert(cv.getPersistSubModelsPath === cv2.getPersistSubModelsPath) assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] @@ -184,6 +194,63 @@ class CrossValidatorSuite .compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps) } + test("CrossValidator expose sub models") { + val lr = new LogisticRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.001, 1000.0)) + .addGrid(lr.maxIter, Array(0, 3)) + .build() + val eval = new BinaryClassificationEvaluator + val numFolds = 3 + val subdirName = Identifiable.randomUID("testSubModels") + val subPath = new File(tempDir, subdirName) + val persistSubModelsPath = new File(subPath, "subModels").toString + + val cv = new CrossValidator() + .setEstimator(lr) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setNumFolds(numFolds) + .setParallelism(1) + .setCollectSubModels(true) + .setPersistSubModelsPath(persistSubModelsPath) + + val cvModel = cv.fit(dataset) + + val subModels = Array.fill(numFolds)(Array.fill[LogisticRegressionModel]( + lrParamMaps.length)(null)) + for (i <- 0 until numFolds) { + val splitPath = new File(persistSubModelsPath, i.toString) + for (j <- 0 until lrParamMaps.length) { + val subModelPath = new File(splitPath, j.toString).toString + subModels(i)(j) = LogisticRegressionModel.load(subModelPath) + } + } + + assert(cvModel.subModels != null && cvModel.subModels.length == numFolds) + cvModel.subModels.foreach(array => assert(array.length == lrParamMaps.length)) + + val savingPathWithoutSubModels = new File(subPath, "cvModel2").getPath + cvModel.save(savingPathWithoutSubModels) + val cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels) + assert(cvModel2.subModels === null) + + val savingPathWithSubModels = new File(subPath, "cvModel3").getPath + cvModel.save(savingPathWithSubModels, persistSubModels = true) + val cvModel3 = CrossValidatorModel.load(savingPathWithSubModels) + assert(cvModel3.subModels != null && cvModel3.subModels.length == numFolds) + cvModel3.subModels.foreach(array => assert(array.length == lrParamMaps.length)) + + for (i <- 0 until numFolds) { + for (j <- 0 until lrParamMaps.length) { + assert(cvModel.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid === + subModels(i)(j).uid) + assert(cvModel3.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid === + subModels(i)(j).uid) + } + } + } + test("read/write: CrossValidator with nested estimator") { val ova = new OneVsRest().setClassifier(new LogisticRegression) val evaluator = new MulticlassClassificationEvaluator() diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 74801733381c..40654896d942 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -17,16 +17,19 @@ package org.apache.spark.ml.tuning +import java.io.File + import org.apache.spark.SparkFunSuite + import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest} import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} import org.apache.spark.ml.linalg.Vectors -import org.apache.spark.ml.param.{ParamMap} +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType @@ -160,11 +163,17 @@ class TrainValidationSplitSuite .setTrainRatio(0.5) .setEstimatorParamMaps(paramMaps) .setSeed(42L) + .setParallelism(2) + .setCollectSubModels(true) + .setPersistSubModelsPath("tvsSubModels") val tvs2 = testDefaultReadWrite(tvs, testParams = false) assert(tvs.getTrainRatio === tvs2.getTrainRatio) assert(tvs.getSeed === tvs2.getSeed) + assert(tvs.getParallelism === tvs2.getParallelism) + assert(tvs.getCollectSubModels === tvs2.getCollectSubModels) + assert(tvs.getPersistSubModelsPath === tvs2.getPersistSubModelsPath) ValidatorParamsSuiteHelpers .compareParamMaps(tvs.getEstimatorParamMaps, tvs2.getEstimatorParamMaps) @@ -179,6 +188,53 @@ class TrainValidationSplitSuite } } + test("TrainValidationSplit expose sub models") { + val lr = new LogisticRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.001, 1000.0)) + .addGrid(lr.maxIter, Array(0, 3)) + .build() + val eval = new BinaryClassificationEvaluator + val subdirName = Identifiable.randomUID("testSubModels") + val subPath = new File(tempDir, subdirName) + val persistSubModelsPath = new File(subPath, "subModels").toString + + val tvs = new TrainValidationSplit() + .setEstimator(lr) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setParallelism(1) + .setCollectSubModels(true) + .setPersistSubModelsPath(persistSubModelsPath) + + val tvsModel = tvs.fit(dataset) + + val subModels = Array.fill[LogisticRegressionModel](lrParamMaps.length)(null) + for (i <- 0 until lrParamMaps.length) { + val subModelPath = new File(persistSubModelsPath, i.toString).toString + subModels(i) = LogisticRegressionModel.load(subModelPath) + } + + assert(tvsModel.subModels != null && tvsModel.subModels.length == lrParamMaps.length) + + val savingPathWithoutSubModels = new File(subPath, "tvsModel2").getPath + tvsModel.save(savingPathWithoutSubModels) + val tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels) + assert(tvsModel2.subModels === null) + + val savingPathWithSubModels = new File(subPath, "tvsModel3").getPath + tvsModel.save(savingPathWithSubModels, persistSubModels = true) + val tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels) + assert(tvsModel3.subModels != null && tvsModel3.subModels.length == lrParamMaps.length) + + for (i <- 0 until lrParamMaps.length) { + assert(tvsModel.subModels(i).asInstanceOf[LogisticRegressionModel].uid === + subModels(i).uid) + assert(tvsModel3.subModels(i).asInstanceOf[LogisticRegressionModel].uid === + subModels(i).uid) + } + } + test("read/write: TrainValidationSplit with nested estimator") { val ova = new OneVsRest() .setClassifier(new LogisticRegression) From ae13440fd2220e28b58df52836f55fe5ed77c43f Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 13 Sep 2017 00:05:16 +0800 Subject: [PATCH 02/10] fix style --- .../scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala | 1 - .../org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala | 1 - 2 files changed, 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 57d26f3f67d1..0851e51e895a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.ml.tuning import java.io.File import org.apache.spark.SparkFunSuite - import org.apache.spark.ml.{Estimator, Model, Pipeline} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest} import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 40654896d942..9f0a7de8c330 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.ml.tuning import java.io.File import org.apache.spark.SparkFunSuite - import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest} import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput From e0f4ce6b21e5179e10a4a8640a3dd1aa0038e291 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 14 Sep 2017 14:34:55 +0800 Subject: [PATCH 03/10] remove code for dump models to disk --- .../ml/param/shared/SharedParamsCodeGen.scala | 4 +--- .../spark/ml/param/shared/sharedParams.scala | 17 ----------------- .../apache/spark/ml/tuning/CrossValidator.scala | 14 ++------------ .../spark/ml/tuning/TrainValidationSplit.scala | 13 ++----------- .../spark/ml/tuning/CrossValidatorSuite.scala | 17 +---------------- .../ml/tuning/TrainValidationSplitSuite.scala | 14 +------------- 6 files changed, 7 insertions(+), 72 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index a6b9fca0a6d2..ad2000bf2096 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -84,9 +84,7 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"), isValid = "ParamValidators.gtEq(2)", isExpertParam = true), ParamDesc[Boolean]("collectSubModels", "whether to collect sub models when tuning fitting", - Some("false"), isExpertParam = true), - ParamDesc[String]("persistSubModelsPath", "The path to persist sub models when " + - "tuning fitting", Some("\"\""), isExpertParam = true) + Some("false"), isExpertParam = true) ) val code = genSharedParams(params) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index fc481c70c219..0927a55d3bce 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -419,21 +419,4 @@ private[ml] trait HasCollectSubModels extends Params { /** @group expertGetParam */ final def getCollectSubModels: Boolean = $(collectSubModels) } - -/** - * Trait for shared param persistSubModelsPath (default: ""). - */ -private[ml] trait HasPersistSubModelsPath extends Params { - - /** - * Param for The path to persist sub models when tuning fitting. - * @group expertParam - */ - final val persistSubModelsPath: Param[String] = new Param[String](this, "persistSubModelsPath", "The path to persist sub models when tuning fitting") - - setDefault(persistSubModelsPath, "") - - /** @group expertGetParam */ - final def getPersistSubModelsPath: String = $(persistSubModelsPath) -} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index da60362a4daa..2500b81067a0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -32,7 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} -import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism, HasPersistSubModelsPath} +import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism} import org.apache.spark.ml.util._ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.{DataFrame, Dataset} @@ -69,7 +69,7 @@ private[ml] trait CrossValidatorParams extends ValidatorParams { class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) extends Estimator[CrossValidatorModel] with CrossValidatorParams with HasParallelism with HasCollectSubModels - with HasPersistSubModelsPath with MLWritable with Logging { + with MLWritable with Logging { @Since("1.2.0") def this() = this(Identifiable.randomUID("cv")) @@ -107,10 +107,6 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) @Since("2.3.0") def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) - /** @group expertSetParam */ - @Since("2.3.0") - def setPersistSubModelsPath(value: String): this.type = set(persistSubModelsPath, value) - @Since("2.0.0") override def fit(dataset: Dataset[_]): CrossValidatorModel = { val schema = dataset.schema @@ -128,7 +124,6 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) logTuningParams(instr) val collectSubModelsParam = $(collectSubModels) - val persistSubModelsPathParam = $(persistSubModelsPath) var subModels: Array[Array[Model[_]]] = if (collectSubModelsParam) { Array.fill($(numFolds))(Array.fill[Model[_]](epm.length)(null)) @@ -149,11 +144,6 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) if (collectSubModelsParam) { subModels(splitIndex)(paramIndex) = model } - if (persistSubModelsPathParam.nonEmpty) { - val modelPath = new Path(new Path(persistSubModelsPathParam, splitIndex.toString), - paramIndex.toString).toString - model.asInstanceOf[MLWritable].save(modelPath) - } model } (executionContext) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 35c3df6e751e..7fb070215803 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -33,7 +33,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} -import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism, HasPersistSubModelsPath} +import org.apache.spark.ml.param.shared.{HasCollectSubModels, HasParallelism} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType @@ -68,7 +68,7 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams { class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String) extends Estimator[TrainValidationSplitModel] with TrainValidationSplitParams with HasParallelism with HasCollectSubModels - with HasPersistSubModelsPath with MLWritable with Logging { + with MLWritable with Logging { @Since("1.5.0") def this() = this(Identifiable.randomUID("tvs")) @@ -106,10 +106,6 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St @Since("2.3.0") def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) - /** @group expertSetParam */ - @Since("2.3.0") - def setPersistSubModelsPath(value: String): this.type = set(persistSubModelsPath, value) - @Since("2.0.0") override def fit(dataset: Dataset[_]): TrainValidationSplitModel = { val schema = dataset.schema @@ -131,7 +127,6 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St validationDataset.cache() val collectSubModelsParam = $(collectSubModels) - val persistSubModelsPathParam = $(persistSubModelsPath) var subModels: Array[Model[_]] = if (collectSubModelsParam) { Array.fill[Model[_]](epm.length)(null) @@ -146,10 +141,6 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St if (collectSubModelsParam) { subModels(paramIndex) = model } - if (persistSubModelsPathParam.nonEmpty) { - val modelPath = new Path(persistSubModelsPathParam, paramIndex.toString).toString - model.asInstanceOf[MLWritable].save(modelPath) - } model } (executionContext) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 0851e51e895a..443f33bdef4c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -164,7 +164,6 @@ class CrossValidatorSuite .setSeed(42L) .setParallelism(2) .setCollectSubModels(true) - .setPersistSubModelsPath("cvSubModels") val cv2 = testDefaultReadWrite(cv, testParams = false) @@ -173,7 +172,6 @@ class CrossValidatorSuite assert(cv.getSeed === cv2.getSeed) assert(cv.getParallelism === cv2.getParallelism) assert(cv.getCollectSubModels === cv2.getCollectSubModels) - assert(cv.getPersistSubModelsPath === cv2.getPersistSubModelsPath) assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] @@ -212,20 +210,9 @@ class CrossValidatorSuite .setNumFolds(numFolds) .setParallelism(1) .setCollectSubModels(true) - .setPersistSubModelsPath(persistSubModelsPath) val cvModel = cv.fit(dataset) - val subModels = Array.fill(numFolds)(Array.fill[LogisticRegressionModel]( - lrParamMaps.length)(null)) - for (i <- 0 until numFolds) { - val splitPath = new File(persistSubModelsPath, i.toString) - for (j <- 0 until lrParamMaps.length) { - val subModelPath = new File(splitPath, j.toString).toString - subModels(i)(j) = LogisticRegressionModel.load(subModelPath) - } - } - assert(cvModel.subModels != null && cvModel.subModels.length == numFolds) cvModel.subModels.foreach(array => assert(array.length == lrParamMaps.length)) @@ -243,9 +230,7 @@ class CrossValidatorSuite for (i <- 0 until numFolds) { for (j <- 0 until lrParamMaps.length) { assert(cvModel.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid === - subModels(i)(j).uid) - assert(cvModel3.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid === - subModels(i)(j).uid) + cvModel3.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 9f0a7de8c330..2b1a69035cdc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -164,7 +164,6 @@ class TrainValidationSplitSuite .setSeed(42L) .setParallelism(2) .setCollectSubModels(true) - .setPersistSubModelsPath("tvsSubModels") val tvs2 = testDefaultReadWrite(tvs, testParams = false) @@ -172,7 +171,6 @@ class TrainValidationSplitSuite assert(tvs.getSeed === tvs2.getSeed) assert(tvs.getParallelism === tvs2.getParallelism) assert(tvs.getCollectSubModels === tvs2.getCollectSubModels) - assert(tvs.getPersistSubModelsPath === tvs2.getPersistSubModelsPath) ValidatorParamsSuiteHelpers .compareParamMaps(tvs.getEstimatorParamMaps, tvs2.getEstimatorParamMaps) @@ -196,7 +194,6 @@ class TrainValidationSplitSuite val eval = new BinaryClassificationEvaluator val subdirName = Identifiable.randomUID("testSubModels") val subPath = new File(tempDir, subdirName) - val persistSubModelsPath = new File(subPath, "subModels").toString val tvs = new TrainValidationSplit() .setEstimator(lr) @@ -204,16 +201,9 @@ class TrainValidationSplitSuite .setEvaluator(eval) .setParallelism(1) .setCollectSubModels(true) - .setPersistSubModelsPath(persistSubModelsPath) val tvsModel = tvs.fit(dataset) - val subModels = Array.fill[LogisticRegressionModel](lrParamMaps.length)(null) - for (i <- 0 until lrParamMaps.length) { - val subModelPath = new File(persistSubModelsPath, i.toString).toString - subModels(i) = LogisticRegressionModel.load(subModelPath) - } - assert(tvsModel.subModels != null && tvsModel.subModels.length == lrParamMaps.length) val savingPathWithoutSubModels = new File(subPath, "tvsModel2").getPath @@ -228,9 +218,7 @@ class TrainValidationSplitSuite for (i <- 0 until lrParamMaps.length) { assert(tvsModel.subModels(i).asInstanceOf[LogisticRegressionModel].uid === - subModels(i).uid) - assert(tvsModel3.subModels(i).asInstanceOf[LogisticRegressionModel].uid === - subModels(i).uid) + tvsModel3.subModels(i).asInstanceOf[LogisticRegressionModel].uid) } } From e009ee1145930a02c71db85c967a49f9fd7509e5 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Wed, 27 Sep 2017 22:59:33 +0800 Subject: [PATCH 04/10] address comment issues --- .../ml/param/shared/SharedParamsCodeGen.scala | 3 +- .../spark/ml/param/shared/sharedParams.scala | 4 +-- .../spark/ml/tuning/CrossValidator.scala | 29 +++++++++---------- .../ml/tuning/TrainValidationSplit.scala | 29 +++++++++---------- .../spark/ml/tuning/CrossValidatorSuite.scala | 14 ++++----- .../ml/tuning/TrainValidationSplitSuite.scala | 10 +++---- 6 files changed, 44 insertions(+), 45 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index ad2000bf2096..f94924b1afaf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -83,7 +83,8 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false), ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"), isValid = "ParamValidators.gtEq(2)", isExpertParam = true), - ParamDesc[Boolean]("collectSubModels", "whether to collect sub models when tuning fitting", + ParamDesc[Boolean]("collectSubModels", "whether to collect a list of sub-models trained " + + "during tuning", Some("false"), isExpertParam = true) ) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 0927a55d3bce..7f1c4c6f559a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -409,10 +409,10 @@ private[ml] trait HasAggregationDepth extends Params { private[ml] trait HasCollectSubModels extends Params { /** - * Param for whether to collect sub models when tuning fitting. + * Param for whether to collect a list of sub-models trained during tuning. * @group expertParam */ - final val collectSubModels: BooleanParam = new BooleanParam(this, "collectSubModels", "whether to collect sub models when tuning fitting") + final val collectSubModels: BooleanParam = new BooleanParam(this, "collectSubModels", "whether to collect a list of sub-models trained during tuning") setDefault(collectSubModels, false) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index cd6c292bfcfa..b258ebe14ccf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -125,9 +125,9 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val collectSubModelsParam = $(collectSubModels) - var subModels: Array[Array[Model[_]]] = if (collectSubModelsParam) { - Array.fill($(numFolds))(Array.fill[Model[_]](epm.length)(null)) - } else null + var subModels: Option[Array[Array[Model[_]]]] = if (collectSubModelsParam) { + Some(Array.fill($(numFolds))(Array.fill[Model[_]](epm.length)(null))) + } else None // Compute metrics for each model over each split val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) @@ -142,7 +142,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] if (collectSubModelsParam) { - subModels(splitIndex)(paramIndex) = model + subModels.get(splitIndex)(paramIndex) = model } model } (executionContext) @@ -253,7 +253,7 @@ class CrossValidatorModel private[ml] ( @Since("1.4.0") override val uid: String, @Since("1.2.0") val bestModel: Model[_], @Since("1.5.0") val avgMetrics: Array[Double], - @Since("2.3.0") val subModels: Array[Array[Model[_]]]) + @Since("2.3.0") val subModels: Option[Array[Array[Model[_]]]]) extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable { /** A Python-friendly auxiliary constructor. */ @@ -300,19 +300,18 @@ class CrossValidatorModel private[ml] ( @Since("1.6.0") object CrossValidatorModel extends MLReadable[CrossValidatorModel] { - private[CrossValidatorModel] def copySubModels(subModels: Array[Array[Model[_]]]) = { - var copiedSubModels: Array[Array[Model[_]]] = null - if (subModels != null) { + private[CrossValidatorModel] def copySubModels(subModels: Option[Array[Array[Model[_]]]]) = { + subModels.map { subModels => val numFolds = subModels.length val numParamMaps = subModels(0).length - copiedSubModels = Array.fill(numFolds)(Array.fill[Model[_]](numParamMaps)(null)) + val copiedSubModels = Array.fill(numFolds)(Array.fill[Model[_]](numParamMaps)(null)) for (i <- 0 until numFolds) { for (j <- 0 until numParamMaps) { copiedSubModels(i)(j) = subModels(i)(j).copy(ParamMap.empty).asInstanceOf[Model[_]] } } + copiedSubModels } - copiedSubModels } @Since("1.6.0") @@ -345,13 +344,13 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { val bestModelPath = new Path(path, "bestModel").toString instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) if (shouldPersistSubModels) { - require(instance.subModels != null, "Cannot get sub models to persist.") + require(instance.subModels.isDefined, "Cannot get sub models to persist.") val subModelsPath = new Path(path, "subModels") for (splitIndex <- 0 until instance.getNumFolds) { val splitPath = new Path(subModelsPath, splitIndex.toString) for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) { val modelPath = new Path(splitPath, paramIndex.toString).toString - instance.subModels(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath) + instance.subModels.get(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath) } } } @@ -374,7 +373,7 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray val shouldPersistSubModels = (metadata.metadata \ "shouldPersistSubModels").extract[Boolean] - val subModels: Array[Array[Model[_]]] = if (shouldPersistSubModels) { + val subModels: Option[Array[Array[Model[_]]]] = if (shouldPersistSubModels) { val subModelsPath = new Path(path, "subModels") val _subModels = Array.fill(numFolds)(Array.fill[Model[_]]( estimatorParamMaps.length)(null)) @@ -386,8 +385,8 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { DefaultParamsReader.loadParamsInstance(modelPath, sc) } } - _subModels - } else null + Some(_subModels) + } else None val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics, subModels) model.set(model.estimator, estimator) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index c6e57fc237f0..7220281e22e1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -128,9 +128,9 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St val collectSubModelsParam = $(collectSubModels) - var subModels: Array[Model[_]] = if (collectSubModelsParam) { - Array.fill[Model[_]](epm.length)(null) - } else null + var subModels: Option[Array[Model[_]]] = if (collectSubModelsParam) { + Some(Array.fill[Model[_]](epm.length)(null)) + } else None // Fit models in a Future for training in parallel logDebug(s"Train split with multiple sets of parameters.") @@ -139,7 +139,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] if (collectSubModelsParam) { - subModels(paramIndex) = model + subModels.get(paramIndex) = model } model } (executionContext) @@ -246,7 +246,7 @@ class TrainValidationSplitModel private[ml] ( @Since("1.5.0") override val uid: String, @Since("1.5.0") val bestModel: Model[_], @Since("1.5.0") val validationMetrics: Array[Double], - @Since("2.3.0") val subModels: Array[Model[_]]) + @Since("2.3.0") val subModels: Option[Array[Model[_]]]) extends Model[TrainValidationSplitModel] with TrainValidationSplitParams with MLWritable { /** A Python-friendly auxiliary constructor. */ @@ -293,16 +293,15 @@ class TrainValidationSplitModel private[ml] ( @Since("2.0.0") object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { - private[TrainValidationSplitModel] def copySubModels(subModels: Array[Model[_]]) = { - var copiedSubModels: Array[Model[_]] = null - if (subModels != null) { + private[TrainValidationSplitModel] def copySubModels(subModels: Option[Array[Model[_]]]) = { + subModels.map { subModels => val numParamMaps = subModels.length - copiedSubModels = Array.fill[Model[_]](numParamMaps)(null) + val copiedSubModels = Array.fill[Model[_]](numParamMaps)(null) for (i <- 0 until numParamMaps) { copiedSubModels(i) = subModels(i).copy(ParamMap.empty).asInstanceOf[Model[_]] } + copiedSubModels } - copiedSubModels } @Since("2.0.0") @@ -335,11 +334,11 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { val bestModelPath = new Path(path, "bestModel").toString instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) if (shouldPersistSubModels) { - require(instance.subModels != null, "Cannot get sub models to persist.") + require(instance.subModels.isDefined, "Cannot get sub models to persist.") val subModelsPath = new Path(path, "subModels") for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) { val modelPath = new Path(subModelsPath, paramIndex.toString).toString - instance.subModels(paramIndex).asInstanceOf[MLWritable].save(modelPath) + instance.subModels.get(paramIndex).asInstanceOf[MLWritable].save(modelPath) } } } @@ -360,7 +359,7 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray val shouldPersistSubModels = (metadata.metadata \ "shouldPersistSubModels").extract[Boolean] - val subModels: Array[Model[_]] = if (shouldPersistSubModels) { + val subModels: Option[Array[Model[_]]] = if (shouldPersistSubModels) { val subModelsPath = new Path(path, "subModels") val _subModels = Array.fill[Model[_]](estimatorParamMaps.length)(null) for (paramIndex <- 0 until estimatorParamMaps.length) { @@ -368,8 +367,8 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { _subModels(paramIndex) = DefaultParamsReader.loadParamsInstance(modelPath, sc) } - _subModels - } else null + Some(_subModels) + } else None val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics, subModels) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 443f33bdef4c..ba83f673c243 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -213,24 +213,24 @@ class CrossValidatorSuite val cvModel = cv.fit(dataset) - assert(cvModel.subModels != null && cvModel.subModels.length == numFolds) - cvModel.subModels.foreach(array => assert(array.length == lrParamMaps.length)) + assert(cvModel.subModels.isDefined && cvModel.subModels.get.length == numFolds) + cvModel.subModels.get.foreach(array => assert(array.length == lrParamMaps.length)) val savingPathWithoutSubModels = new File(subPath, "cvModel2").getPath cvModel.save(savingPathWithoutSubModels) val cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels) - assert(cvModel2.subModels === null) + assert(cvModel2.subModels.isEmpty) val savingPathWithSubModels = new File(subPath, "cvModel3").getPath cvModel.save(savingPathWithSubModels, persistSubModels = true) val cvModel3 = CrossValidatorModel.load(savingPathWithSubModels) - assert(cvModel3.subModels != null && cvModel3.subModels.length == numFolds) - cvModel3.subModels.foreach(array => assert(array.length == lrParamMaps.length)) + assert(cvModel3.subModels.isDefined && cvModel3.subModels.get.length == numFolds) + cvModel3.subModels.get.foreach(array => assert(array.length == lrParamMaps.length)) for (i <- 0 until numFolds) { for (j <- 0 until lrParamMaps.length) { - assert(cvModel.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid === - cvModel3.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid) + assert(cvModel.subModels.get(i)(j).asInstanceOf[LogisticRegressionModel].uid === + cvModel3.subModels.get(i)(j).asInstanceOf[LogisticRegressionModel].uid) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 2b1a69035cdc..d9402ece132e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -204,21 +204,21 @@ class TrainValidationSplitSuite val tvsModel = tvs.fit(dataset) - assert(tvsModel.subModels != null && tvsModel.subModels.length == lrParamMaps.length) + assert(tvsModel.subModels.isDefined && tvsModel.subModels.get.length == lrParamMaps.length) val savingPathWithoutSubModels = new File(subPath, "tvsModel2").getPath tvsModel.save(savingPathWithoutSubModels) val tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels) - assert(tvsModel2.subModels === null) + assert(tvsModel2.subModels.isEmpty) val savingPathWithSubModels = new File(subPath, "tvsModel3").getPath tvsModel.save(savingPathWithSubModels, persistSubModels = true) val tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels) - assert(tvsModel3.subModels != null && tvsModel3.subModels.length == lrParamMaps.length) + assert(tvsModel3.subModels.isDefined && tvsModel3.subModels.get.length == lrParamMaps.length) for (i <- 0 until lrParamMaps.length) { - assert(tvsModel.subModels(i).asInstanceOf[LogisticRegressionModel].uid === - tvsModel3.subModels(i).asInstanceOf[LogisticRegressionModel].uid) + assert(tvsModel.subModels.get(i).asInstanceOf[LogisticRegressionModel].uid === + tvsModel3.subModels.get(i).asInstanceOf[LogisticRegressionModel].uid) } } From 931fa6cff0b036028f110214b5a602d9a64323ef Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 3 Nov 2017 17:01:55 +0800 Subject: [PATCH 05/10] address issues from comments --- .../ml/param/shared/SharedParamsCodeGen.scala | 6 +- .../spark/ml/tuning/CrossValidator.scala | 56 ++++++++++-------- .../ml/tuning/TrainValidationSplit.scala | 57 +++++++++++-------- .../org/apache/spark/ml/util/ReadWrite.scala | 7 +++ .../spark/ml/tuning/CrossValidatorSuite.scala | 28 ++++----- .../ml/tuning/TrainValidationSplitSuite.scala | 24 ++++---- 6 files changed, 103 insertions(+), 75 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index f94924b1afaf..05ff6acea5b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -83,8 +83,10 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false), ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"), isValid = "ParamValidators.gtEq(2)", isExpertParam = true), - ParamDesc[Boolean]("collectSubModels", "whether to collect a list of sub-models trained " + - "during tuning", + ParamDesc[Boolean]("collectSubModels", "If set to false, then only the single best " + + "sub-model will be available after fitting. If set to true, then all sub-models will be " + + "available. Warning: For large models, collecting all sub-models can cause OOMs on the " + + "Spark driver.", Some("false"), isExpertParam = true) ) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index b258ebe14ccf..d4d3661ee624 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.tuning import java.io.IOException -import java.util.{List => JList} +import java.util.{Locale, List => JList} import scala.collection.JavaConverters._ import scala.concurrent.Future @@ -176,7 +176,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] instr.logSuccess(bestModel) - copyValues(new CrossValidatorModel(uid, bestModel, metrics, subModels).setParent(this)) + copyValues(new CrossValidatorModel(uid, bestModel, metrics) + .setSubModels(subModels).setParent(this)) } @Since("1.4.0") @@ -252,19 +253,29 @@ object CrossValidator extends MLReadable[CrossValidator] { class CrossValidatorModel private[ml] ( @Since("1.4.0") override val uid: String, @Since("1.2.0") val bestModel: Model[_], - @Since("1.5.0") val avgMetrics: Array[Double], - @Since("2.3.0") val subModels: Option[Array[Array[Model[_]]]]) + @Since("1.5.0") val avgMetrics: Array[Double]) extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable { /** A Python-friendly auxiliary constructor. */ private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: JList[Double]) = { - this(uid, bestModel, avgMetrics.asScala.toArray, null) + this(uid, bestModel, avgMetrics.asScala.toArray) } - private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: Array[Double]) = { - this(uid, bestModel, avgMetrics, null) + private var _subModels: Option[Array[Array[Model[_]]]] = None + + @Since("2.3.0") + private[tuning] def setSubModels(subModels: Option[Array[Array[Model[_]]]]) + : CrossValidatorModel = { + _subModels = subModels + this } + @Since("2.3.0") + def subModels: Array[Array[Model[_]]] = _subModels.get + + @Since("2.3.0") + def hasSubModels: Boolean = _subModels.isDefined + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) @@ -281,20 +292,13 @@ class CrossValidatorModel private[ml] ( val copied = new CrossValidatorModel( uid, bestModel.copy(extra).asInstanceOf[Model[_]], - avgMetrics.clone(), - CrossValidatorModel.copySubModels(subModels)) + avgMetrics.clone() + ).setSubModels(CrossValidatorModel.copySubModels(_subModels)) copyValues(copied, extra).setParent(parent) } @Since("1.6.0") override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this) - - @Since("2.3.0") - @throws[IOException]("If the input path already exists but overwrite is not enabled.") - def save(path: String, persistSubModels: Boolean): Unit = { - write.asInstanceOf[CrossValidatorModel.CrossValidatorModelWriter] - .persistSubModels(persistSubModels).save(path) - } } @Since("1.6.0") @@ -325,14 +329,19 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { ValidatorParams.validateParams(instance) - protected var shouldPersistSubModels: Boolean = false + protected var shouldPersistSubModels: Boolean = if (instance.hasSubModels) true else false /** - * Set option for persist sub models. + * Extra options for CrossValidatorModelWriter, current support "persistSubModels". + * if sub models exsit, the default value for option "persistSubModels" is "true". */ @Since("2.3.0") - def persistSubModels(persist: Boolean): this.type = { - shouldPersistSubModels = persist + override def option(key: String, value: String): this.type = { + key.toLowerCase(Locale.ROOT) match { + case "persistsubmodels" => shouldPersistSubModels = value.toBoolean + case _ => throw new IllegalArgumentException( + s"Illegal option ${key} for CrossValidatorModelWriter") + } this } @@ -344,13 +353,13 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { val bestModelPath = new Path(path, "bestModel").toString instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) if (shouldPersistSubModels) { - require(instance.subModels.isDefined, "Cannot get sub models to persist.") + require(instance.hasSubModels, "Cannot get sub models to persist.") val subModelsPath = new Path(path, "subModels") for (splitIndex <- 0 until instance.getNumFolds) { val splitPath = new Path(subModelsPath, splitIndex.toString) for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) { val modelPath = new Path(splitPath, paramIndex.toString).toString - instance.subModels.get(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath) + instance.subModels(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath) } } } @@ -388,7 +397,8 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { Some(_subModels) } else None - val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics, subModels) + val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics) + .setSubModels(subModels) model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 7220281e22e1..a34895be0feb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.tuning import java.io.IOException -import java.util.{List => JList} +import java.util.{Locale, List => JList} import scala.collection.JavaConverters._ import scala.concurrent.Future @@ -173,7 +173,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St logInfo(s"Best train validation split metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] instr.logSuccess(bestModel) - copyValues(new TrainValidationSplitModel(uid, bestModel, metrics, subModels).setParent(this)) + copyValues(new TrainValidationSplitModel(uid, bestModel, metrics) + .setSubModels(subModels).setParent(this)) } @Since("1.5.0") @@ -245,19 +246,29 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] { class TrainValidationSplitModel private[ml] ( @Since("1.5.0") override val uid: String, @Since("1.5.0") val bestModel: Model[_], - @Since("1.5.0") val validationMetrics: Array[Double], - @Since("2.3.0") val subModels: Option[Array[Model[_]]]) + @Since("1.5.0") val validationMetrics: Array[Double]) extends Model[TrainValidationSplitModel] with TrainValidationSplitParams with MLWritable { /** A Python-friendly auxiliary constructor. */ private[ml] def this(uid: String, bestModel: Model[_], validationMetrics: JList[Double]) = { - this(uid, bestModel, validationMetrics.asScala.toArray, null) + this(uid, bestModel, validationMetrics.asScala.toArray) } - private[ml] def this(uid: String, bestModel: Model[_], validationMetrics: Array[Double]) = { - this(uid, bestModel, validationMetrics, null) + private var _subModels: Option[Array[Model[_]]] = None + + @Since("2.3.0") + private[tuning] def setSubModels(subModels: Option[Array[Model[_]]]) + : TrainValidationSplitModel = { + _subModels = subModels + this } + @Since("2.3.0") + def subModels: Array[Model[_]] = _subModels.get + + @Since("2.3.0") + def hasSubModels: Boolean = _subModels.isDefined + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) @@ -274,20 +285,13 @@ class TrainValidationSplitModel private[ml] ( val copied = new TrainValidationSplitModel ( uid, bestModel.copy(extra).asInstanceOf[Model[_]], - validationMetrics.clone(), - TrainValidationSplitModel.copySubModels(subModels)) + validationMetrics.clone() + ).setSubModels(TrainValidationSplitModel.copySubModels(_subModels)) copyValues(copied, extra).setParent(parent) } @Since("2.0.0") override def write: MLWriter = new TrainValidationSplitModel.TrainValidationSplitModelWriter(this) - - @Since("2.3.0") - @throws[IOException]("If the input path already exists but overwrite is not enabled.") - def save(path: String, persistSubModels: Boolean): Unit = { - write.asInstanceOf[TrainValidationSplitModel.TrainValidationSplitModelWriter] - .persistSubModels(persistSubModels).save(path) - } } @Since("2.0.0") @@ -315,14 +319,19 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { ValidatorParams.validateParams(instance) - protected var shouldPersistSubModels: Boolean = false + protected var shouldPersistSubModels: Boolean = if (instance.hasSubModels) true else false /** - * Set option for persist sub models. + * Extra options for TrainValidationSplitModelWriter, current support "persistSubModels". + * if sub models exsit, the default value for option "persistSubModels" is "true". */ @Since("2.3.0") - def persistSubModels(persist: Boolean): this.type = { - shouldPersistSubModels = persist + override def option(key: String, value: String): this.type = { + key.toLowerCase(Locale.ROOT) match { + case "persistsubmodels" => shouldPersistSubModels = value.toBoolean + case _ => throw new IllegalArgumentException( + s"Illegal option ${key} for TrainValidationSplitModelWriter") + } this } @@ -334,11 +343,11 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { val bestModelPath = new Path(path, "bestModel").toString instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) if (shouldPersistSubModels) { - require(instance.subModels.isDefined, "Cannot get sub models to persist.") + require(instance.hasSubModels, "Cannot get sub models to persist.") val subModelsPath = new Path(path, "subModels") for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) { val modelPath = new Path(subModelsPath, paramIndex.toString).toString - instance.subModels.get(paramIndex).asInstanceOf[MLWritable].save(modelPath) + instance.subModels(paramIndex).asInstanceOf[MLWritable].save(modelPath) } } } @@ -370,8 +379,8 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { Some(_subModels) } else None - val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics, - subModels) + val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics) + .setSubModels(subModels) model.set(model.estimator, estimator) .set(model.evaluator, evaluator) .set(model.estimatorParamMaps, estimatorParamMaps) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 7188da353126..dafc4326bdae 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -107,6 +107,13 @@ abstract class MLWriter extends BaseReadWrite with Logging { @Since("1.6.0") protected def saveImpl(path: String): Unit + /** + * `option()` handles extra options. If subclasses need to support extra options, override this + * method. + */ + @Since("2.3.0") + def option(key: String, value: String): this.type = this + /** * Overwrites if the output path already exists. */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index ba83f673c243..4a6a4ac509fe 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -199,8 +199,7 @@ class CrossValidatorSuite .build() val eval = new BinaryClassificationEvaluator val numFolds = 3 - val subdirName = Identifiable.randomUID("testSubModels") - val subPath = new File(tempDir, subdirName) + val subPath = new File(tempDir, "testCrossValidatorSubModels") val persistSubModelsPath = new File(subPath, "subModels").toString val cv = new CrossValidator() @@ -213,24 +212,25 @@ class CrossValidatorSuite val cvModel = cv.fit(dataset) - assert(cvModel.subModels.isDefined && cvModel.subModels.get.length == numFolds) - cvModel.subModels.get.foreach(array => assert(array.length == lrParamMaps.length)) - - val savingPathWithoutSubModels = new File(subPath, "cvModel2").getPath - cvModel.save(savingPathWithoutSubModels) - val cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels) - assert(cvModel2.subModels.isEmpty) + assert(cvModel.hasSubModels && cvModel.subModels.length == numFolds) + cvModel.subModels.foreach(array => assert(array.length == lrParamMaps.length)) + // Test the default value for option "persistSubModel" to be "true" val savingPathWithSubModels = new File(subPath, "cvModel3").getPath - cvModel.save(savingPathWithSubModels, persistSubModels = true) + cvModel.save(savingPathWithSubModels) val cvModel3 = CrossValidatorModel.load(savingPathWithSubModels) - assert(cvModel3.subModels.isDefined && cvModel3.subModels.get.length == numFolds) - cvModel3.subModels.get.foreach(array => assert(array.length == lrParamMaps.length)) + assert(cvModel3.hasSubModels && cvModel3.subModels.length == numFolds) + cvModel3.subModels.foreach(array => assert(array.length == lrParamMaps.length)) + + val savingPathWithoutSubModels = new File(subPath, "cvModel2").getPath + cvModel.write.option("persistSubModels", "false").save(savingPathWithoutSubModels) + val cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels) + assert(!cvModel2.hasSubModels) for (i <- 0 until numFolds) { for (j <- 0 until lrParamMaps.length) { - assert(cvModel.subModels.get(i)(j).asInstanceOf[LogisticRegressionModel].uid === - cvModel3.subModels.get(i)(j).asInstanceOf[LogisticRegressionModel].uid) + assert(cvModel.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid === + cvModel3.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index d9402ece132e..96a98a9e10ba 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -192,8 +192,7 @@ class TrainValidationSplitSuite .addGrid(lr.maxIter, Array(0, 3)) .build() val eval = new BinaryClassificationEvaluator - val subdirName = Identifiable.randomUID("testSubModels") - val subPath = new File(tempDir, subdirName) + val subPath = new File(tempDir, "testTrainValidationSplitSubModels") val tvs = new TrainValidationSplit() .setEstimator(lr) @@ -204,21 +203,22 @@ class TrainValidationSplitSuite val tvsModel = tvs.fit(dataset) - assert(tvsModel.subModels.isDefined && tvsModel.subModels.get.length == lrParamMaps.length) - - val savingPathWithoutSubModels = new File(subPath, "tvsModel2").getPath - tvsModel.save(savingPathWithoutSubModels) - val tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels) - assert(tvsModel2.subModels.isEmpty) + assert(tvsModel.hasSubModels && tvsModel.subModels.length == lrParamMaps.length) + // Test the default value for option "persistSubModel" to be "true" val savingPathWithSubModels = new File(subPath, "tvsModel3").getPath - tvsModel.save(savingPathWithSubModels, persistSubModels = true) + tvsModel.save(savingPathWithSubModels) val tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels) - assert(tvsModel3.subModels.isDefined && tvsModel3.subModels.get.length == lrParamMaps.length) + assert(tvsModel3.hasSubModels && tvsModel3.subModels.length == lrParamMaps.length) + + val savingPathWithoutSubModels = new File(subPath, "tvsModel2").getPath + tvsModel.write.option("persistSubModels", "false").save(savingPathWithoutSubModels) + val tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels) + assert(!tvsModel2.hasSubModels) for (i <- 0 until lrParamMaps.length) { - assert(tvsModel.subModels.get(i).asInstanceOf[LogisticRegressionModel].uid === - tvsModel3.subModels.get(i).asInstanceOf[LogisticRegressionModel].uid) + assert(tvsModel.subModels(i).asInstanceOf[LogisticRegressionModel].uid === + tvsModel3.subModels(i).asInstanceOf[LogisticRegressionModel].uid) } } From 2a83fb5071aa879dbae7791eb7c956793246f8df Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Fri, 3 Nov 2017 17:32:32 +0800 Subject: [PATCH 06/10] fix style --- .../main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 3 +-- .../org/apache/spark/ml/tuning/TrainValidationSplit.scala | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index d4d3661ee624..885fae301942 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.tuning -import java.io.IOException -import java.util.{Locale, List => JList} +import java.util.{List => JList, Locale} import scala.collection.JavaConverters._ import scala.concurrent.Future diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index a34895be0feb..f72c584a1699 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -17,8 +17,7 @@ package org.apache.spark.ml.tuning -import java.io.IOException -import java.util.{Locale, List => JList} +import java.util.{List => JList, Locale} import scala.collection.JavaConverters._ import scala.concurrent.Future From f2ef609ab462f1c5ce3044f2b3a3e5389a77949e Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Mon, 6 Nov 2017 11:10:13 +0800 Subject: [PATCH 07/10] address comments from joseph --- .../spark/ml/tuning/CrossValidator.scala | 96 +++++++++++-------- .../ml/tuning/TrainValidationSplit.scala | 86 ++++++++++------- .../org/apache/spark/ml/util/ReadWrite.scala | 17 +++- .../spark/ml/tuning/CrossValidatorSuite.scala | 5 + .../ml/tuning/TrainValidationSplitSuite.scala | 5 + 5 files changed, 131 insertions(+), 78 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 885fae301942..05f053e26532 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.tuning -import java.util.{List => JList, Locale} +import java.util.{List => JList} import scala.collection.JavaConverters._ import scala.concurrent.Future @@ -102,7 +102,18 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) @Since("2.3.0") def setParallelism(value: Int): this.type = set(parallelism, value) - /** @group expertSetParam */ + /** + * Whether to collect submodels when fitting. If set, we can get submodels from + * the returned model. + * + * Note: If set this param, when you save the returned model, you can set an option + * "persistSubModels" to be "true" before saving, in order to save these submodels. + * You can check documents of + * {@link org.apache.spark.ml.tuning.CrossValidatorModel.CrossValidatorModelWriter} + * for more information. + * + * @group expertSetParam + */ @Since("2.3.0") def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) @@ -262,15 +273,26 @@ class CrossValidatorModel private[ml] ( private var _subModels: Option[Array[Array[Model[_]]]] = None - @Since("2.3.0") private[tuning] def setSubModels(subModels: Option[Array[Array[Model[_]]]]) : CrossValidatorModel = { _subModels = subModels this } + /** + * @return submodels represented in two dimension array. The index of outer array is the + * fold index, and the index of inner array corresponds to the ordering of + * estimatorParamsMaps + * + * Note: If submodels not available, exception will be thrown. only when we set collectSubModels + * Param before fitting, submodels will be available. + */ @Since("2.3.0") - def subModels: Array[Array[Model[_]]] = _subModels.get + def subModels: Array[Array[Model[_]]] = { + require(_subModels.isDefined, "submodels not available, set collectSubModels param before " + + "fitting will address this issue.") + _subModels.get + } @Since("2.3.0") def hasSubModels: Boolean = _subModels.isDefined @@ -297,24 +319,17 @@ class CrossValidatorModel private[ml] ( } @Since("1.6.0") - override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this) + override def write: CrossValidatorModel.CrossValidatorModelWriter = { + new CrossValidatorModel.CrossValidatorModelWriter(this) + } } @Since("1.6.0") object CrossValidatorModel extends MLReadable[CrossValidatorModel] { - private[CrossValidatorModel] def copySubModels(subModels: Option[Array[Array[Model[_]]]]) = { - subModels.map { subModels => - val numFolds = subModels.length - val numParamMaps = subModels(0).length - val copiedSubModels = Array.fill(numFolds)(Array.fill[Model[_]](numParamMaps)(null)) - for (i <- 0 until numFolds) { - for (j <- 0 until numParamMaps) { - copiedSubModels(i)(j) = subModels(i)(j).copy(ParamMap.empty).asInstanceOf[Model[_]] - } - } - copiedSubModels - } + private[CrossValidatorModel] def copySubModels(subModels: Option[Array[Array[Model[_]]]]) + : Option[Array[Array[Model[_]]]] = { + subModels.map(_.map(_.map(_.copy(ParamMap.empty).asInstanceOf[Model[_]]))) } @Since("1.6.0") @@ -323,39 +338,40 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { @Since("1.6.0") override def load(path: String): CrossValidatorModel = super.load(path) - private[CrossValidatorModel] + /** + * Writer for CrossValidatorModel. + * @param instance CrossValidatorModel instance used to construct the writer + * + * Options: + * CrossValidatorModelWriter support an option "persistSubModels", available value is + * "true" or "false". If you set collectSubModels param before fitting, and then you can set + * the option "persistSubModels" to be "true" and the submodels will be persisted. + * The default value of "persistSubModels" will be "true", if you set collectSubModels + * param before fitting, but if you do not set collectSubModels param before fitting, setting + * "persistSubModels" will cause exception. + */ + @Since("2.3.0") class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter { ValidatorParams.validateParams(instance) - protected var shouldPersistSubModels: Boolean = if (instance.hasSubModels) true else false - - /** - * Extra options for CrossValidatorModelWriter, current support "persistSubModels". - * if sub models exsit, the default value for option "persistSubModels" is "true". - */ - @Since("2.3.0") - override def option(key: String, value: String): this.type = { - key.toLowerCase(Locale.ROOT) match { - case "persistsubmodels" => shouldPersistSubModels = value.toBoolean - case _ => throw new IllegalArgumentException( - s"Illegal option ${key} for CrossValidatorModelWriter") - } - this - } - override protected def saveImpl(path: String): Unit = { + val persistSubModels = optionMap.getOrElse("persistsubmodels", + if (instance.hasSubModels) "true" else "false").toBoolean + import org.json4s.JsonDSL._ val extraMetadata = ("avgMetrics" -> instance.avgMetrics.toSeq) ~ - ("shouldPersistSubModels" -> shouldPersistSubModels) + ("persistSubModels" -> persistSubModels) ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata)) val bestModelPath = new Path(path, "bestModel").toString instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) - if (shouldPersistSubModels) { - require(instance.hasSubModels, "Cannot get sub models to persist.") + if (persistSubModels) { + require(instance.hasSubModels, "When persisting tuning models, you can only set " + + "persistSubModels to true if the tuning was done with collectSubModels set to true. " + + "To save the sub-models, try rerunning fitting with collectSubModels set to true.") val subModelsPath = new Path(path, "subModels") for (splitIndex <- 0 until instance.getNumFolds) { - val splitPath = new Path(subModelsPath, splitIndex.toString) + val splitPath = new Path(subModelsPath, s"fold${splitIndex.toString}") for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) { val modelPath = new Path(splitPath, paramIndex.toString).toString instance.subModels(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath) @@ -379,14 +395,14 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { val bestModelPath = new Path(path, "bestModel").toString val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray - val shouldPersistSubModels = (metadata.metadata \ "shouldPersistSubModels").extract[Boolean] + val shouldPersistSubModels = (metadata.metadata \ "persistSubModels").extract[Boolean] val subModels: Option[Array[Array[Model[_]]]] = if (shouldPersistSubModels) { val subModelsPath = new Path(path, "subModels") val _subModels = Array.fill(numFolds)(Array.fill[Model[_]]( estimatorParamMaps.length)(null)) for (splitIndex <- 0 until numFolds) { - val splitPath = new Path(subModelsPath, splitIndex.toString) + val splitPath = new Path(subModelsPath, s"fold${splitIndex.toString}") for (paramIndex <- 0 until estimatorParamMaps.length) { val modelPath = new Path(splitPath, paramIndex.toString).toString _subModels(splitIndex)(paramIndex) = diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index f72c584a1699..7bb8aad77f14 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -101,8 +101,18 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St @Since("2.3.0") def setParallelism(value: Int): this.type = set(parallelism, value) - /** @group expertSetParam */ - @Since("2.3.0") + /** + * Whether to collect submodels when fitting. If set, we can get submodels from + * the returned model. + * + * Note: If set this param, when you save the returned model, you can set an option + * "persistSubModels" to be "true" before saving, in order to save these submodels. + * You can check documents of + * {@link org.apache.spark.ml.tuning.CrossValidatorModel.CrossValidatorModelWriter} + * for more information. + * + * @group expertSetParam + */@Since("2.3.0") def setCollectSubModels(value: Boolean): this.type = set(collectSubModels, value) @Since("2.0.0") @@ -192,7 +202,9 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St } @Since("2.0.0") - override def write: MLWriter = new TrainValidationSplit.TrainValidationSplitWriter(this) + override def write: TrainValidationSplit.TrainValidationSplitWriter = { + new TrainValidationSplit.TrainValidationSplitWriter(this) + } } @Since("2.0.0") @@ -255,15 +267,25 @@ class TrainValidationSplitModel private[ml] ( private var _subModels: Option[Array[Model[_]]] = None - @Since("2.3.0") private[tuning] def setSubModels(subModels: Option[Array[Model[_]]]) : TrainValidationSplitModel = { _subModels = subModels this } + /** + * @return submodels represented in array. The index of array corresponds to the ordering of + * estimatorParamsMaps + * + * Note: If submodels not available, exception will be thrown. only when we set collectSubModels + * Param before fitting, submodels will be available. + */ @Since("2.3.0") - def subModels: Array[Model[_]] = _subModels.get + def subModels: Array[Model[_]] = { + require(_subModels.isDefined, "submodels not available, set collectSubModels param before " + + "fitting will address this issue.") + _subModels.get + } @Since("2.3.0") def hasSubModels: Boolean = _subModels.isDefined @@ -296,15 +318,9 @@ class TrainValidationSplitModel private[ml] ( @Since("2.0.0") object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { - private[TrainValidationSplitModel] def copySubModels(subModels: Option[Array[Model[_]]]) = { - subModels.map { subModels => - val numParamMaps = subModels.length - val copiedSubModels = Array.fill[Model[_]](numParamMaps)(null) - for (i <- 0 until numParamMaps) { - copiedSubModels(i) = subModels(i).copy(ParamMap.empty).asInstanceOf[Model[_]] - } - copiedSubModels - } + private[TrainValidationSplitModel] def copySubModels(subModels: Option[Array[Model[_]]]) + : Option[Array[Model[_]]] = { + subModels.map(_.map(_.copy(ParamMap.empty).asInstanceOf[Model[_]])) } @Since("2.0.0") @@ -313,36 +329,36 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { @Since("2.0.0") override def load(path: String): TrainValidationSplitModel = super.load(path) - private[TrainValidationSplitModel] + /** + * Writer for TrainValidationSplitModel. + * @param instance TrainValidationSplitModel instance used to construct the writer + * + * Options: + * TrainValidationSplitModel support an option "persistSubModels", available value is + * "true" or "false". If you set collectSubModels param before fitting, and then you can set + * the option "persistSubModels" to be "true" and the submodels will be persisted. + * The default value of "persistSubModels" will be "true", if you set collectSubModels + * param before fitting, but if you do not set collectSubModels param before fitting, setting + * "persistSubModels" will cause exception. + */ class TrainValidationSplitModelWriter(instance: TrainValidationSplitModel) extends MLWriter { ValidatorParams.validateParams(instance) - protected var shouldPersistSubModels: Boolean = if (instance.hasSubModels) true else false - - /** - * Extra options for TrainValidationSplitModelWriter, current support "persistSubModels". - * if sub models exsit, the default value for option "persistSubModels" is "true". - */ - @Since("2.3.0") - override def option(key: String, value: String): this.type = { - key.toLowerCase(Locale.ROOT) match { - case "persistsubmodels" => shouldPersistSubModels = value.toBoolean - case _ => throw new IllegalArgumentException( - s"Illegal option ${key} for TrainValidationSplitModelWriter") - } - this - } - override protected def saveImpl(path: String): Unit = { + val persistSubModels = optionMap.getOrElse("persistsubmodels", + if (instance.hasSubModels) "true" else "false").toBoolean + import org.json4s.JsonDSL._ val extraMetadata = ("validationMetrics" -> instance.validationMetrics.toSeq) ~ - ("shouldPersistSubModels" -> shouldPersistSubModels) + ("persistSubModels" -> persistSubModels) ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata)) val bestModelPath = new Path(path, "bestModel").toString instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath) - if (shouldPersistSubModels) { - require(instance.hasSubModels, "Cannot get sub models to persist.") + if (persistSubModels) { + require(instance.hasSubModels, "When persisting tuning models, you can only set " + + "persistSubModels to true if the tuning was done with collectSubModels set to true. " + + "To save the sub-models, try rerunning fitting with collectSubModels set to true.") val subModelsPath = new Path(path, "subModels") for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) { val modelPath = new Path(subModelsPath, paramIndex.toString).toString @@ -365,7 +381,7 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { val bestModelPath = new Path(path, "bestModel").toString val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray - val shouldPersistSubModels = (metadata.metadata \ "shouldPersistSubModels").extract[Boolean] + val shouldPersistSubModels = (metadata.metadata \ "persistSubModels").extract[Boolean] val subModels: Option[Array[Model[_]]] = if (shouldPersistSubModels) { val subModelsPath = new Path(path, "subModels") diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index dafc4326bdae..edc39781c26c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -18,6 +18,9 @@ package org.apache.spark.ml.util import java.io.IOException +import java.util.Locale + +import scala.collection.mutable import org.apache.hadoop.fs.Path import org.json4s._ @@ -108,11 +111,19 @@ abstract class MLWriter extends BaseReadWrite with Logging { protected def saveImpl(path: String): Unit /** - * `option()` handles extra options. If subclasses need to support extra options, override this - * method. + * Map store extra options for this writer. + */ + protected val optionMap: mutable.Map[String, String] = new mutable.HashMap[String, String]() + + /** + * `option()` handles extra options. */ @Since("2.3.0") - def option(key: String, value: String): this.type = this + def option(key: String, value: String): this.type = { + require(key != null && !key.isEmpty) + optionMap.put(key.toLowerCase(Locale.ROOT), value) + this + } /** * Overwrites if the output path already exists. diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 4a6a4ac509fe..24b608eafeca 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -233,6 +233,11 @@ class CrossValidatorSuite cvModel3.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid) } } + + val savingPathTestingIllegalParam = new File(subPath, "cvModel4").getPath + intercept[IllegalArgumentException] { + cvModel2.write.option("persistSubModels", "true").save(savingPathTestingIllegalParam) + } } test("read/write: CrossValidator with nested estimator") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 96a98a9e10ba..2e5ece9bb4f4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -220,6 +220,11 @@ class TrainValidationSplitSuite assert(tvsModel.subModels(i).asInstanceOf[LogisticRegressionModel].uid === tvsModel3.subModels(i).asInstanceOf[LogisticRegressionModel].uid) } + + val savingPathTestingIllegalParam = new File(subPath, "tvsModel4").getPath + intercept[IllegalArgumentException] { + tvsModel2.write.option("persistSubModels", "true").save(savingPathTestingIllegalParam) + } } test("read/write: TrainValidationSplit with nested estimator") { From 7bacfcac2e20552bb4557614ba477cf776bdf8af Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 7 Nov 2017 14:08:14 +0800 Subject: [PATCH 08/10] address comments from joseph --- .../spark/ml/tuning/CrossValidator.scala | 38 ++++++++++--------- .../ml/tuning/TrainValidationSplit.scala | 38 ++++++++++--------- .../org/apache/spark/ml/util/ReadWrite.scala | 5 ++- 3 files changed, 45 insertions(+), 36 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 05f053e26532..3db70a9d5877 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.tuning -import java.util.{List => JList} +import java.util.{List => JList, Locale} import scala.collection.JavaConverters._ import scala.concurrent.Future @@ -282,15 +282,14 @@ class CrossValidatorModel private[ml] ( /** * @return submodels represented in two dimension array. The index of outer array is the * fold index, and the index of inner array corresponds to the ordering of - * estimatorParamsMaps - * - * Note: If submodels not available, exception will be thrown. only when we set collectSubModels - * Param before fitting, submodels will be available. + * estimatorParamMaps + * @throws IllegalArgumentException if subModels are not available. To retrieve subModels, + * make sure to set collectSubModels to true before fitting. */ @Since("2.3.0") def subModels: Array[Array[Model[_]]] = { - require(_subModels.isDefined, "submodels not available, set collectSubModels param before " + - "fitting will address this issue.") + require(_subModels.isDefined, "subModels not available, To retrieve subModels, make sure " + + "to set collectSubModels to true before fitting.") _subModels.get } @@ -342,22 +341,27 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { * Writer for CrossValidatorModel. * @param instance CrossValidatorModel instance used to construct the writer * - * Options: - * CrossValidatorModelWriter support an option "persistSubModels", available value is - * "true" or "false". If you set collectSubModels param before fitting, and then you can set - * the option "persistSubModels" to be "true" and the submodels will be persisted. - * The default value of "persistSubModels" will be "true", if you set collectSubModels - * param before fitting, but if you do not set collectSubModels param before fitting, setting - * "persistSubModels" will cause exception. + * CrossValidatorModelWriter supports an option "persistSubModels", with possible values + * "true" or "false". If you set the collectSubModels Param before fitting, then you can + * set "persistSubModels" to "true" in order to persist the subModels. By default, + * "persistSubModels" will be "true" when subModels are available and "false" otherwise. + * If subModels are not available, then setting "persistSubModels" to "true" will cause + * an exception. */ @Since("2.3.0") - class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter { + final class CrossValidatorModelWriter private[tuning] ( + instance: CrossValidatorModel) extends MLWriter { ValidatorParams.validateParams(instance) override protected def saveImpl(path: String): Unit = { - val persistSubModels = optionMap.getOrElse("persistsubmodels", - if (instance.hasSubModels) "true" else "false").toBoolean + val persistSubModelsParam = optionMap.getOrElse("persistsubmodels", + if (instance.hasSubModels) "true" else "false") + + require(Array("true", "false").contains(persistSubModelsParam.toLowerCase(Locale.ROOT)), + s"persistSubModels option value ${persistSubModelsParam} is invalid, the possible " + + "values are \"true\" or \"false\"") + val persistSubModels = persistSubModelsParam.toBoolean import org.json4s.JsonDSL._ val extraMetadata = ("avgMetrics" -> instance.avgMetrics.toSeq) ~ diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 7bb8aad77f14..7f01c4beb2c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -108,7 +108,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St * Note: If set this param, when you save the returned model, you can set an option * "persistSubModels" to be "true" before saving, in order to save these submodels. * You can check documents of - * {@link org.apache.spark.ml.tuning.CrossValidatorModel.CrossValidatorModelWriter} + * {@link org.apache.spark.ml.tuning.TrainValidationSplitModel.TrainValidationSplitModelWriter} * for more information. * * @group expertSetParam @@ -275,15 +275,14 @@ class TrainValidationSplitModel private[ml] ( /** * @return submodels represented in array. The index of array corresponds to the ordering of - * estimatorParamsMaps - * - * Note: If submodels not available, exception will be thrown. only when we set collectSubModels - * Param before fitting, submodels will be available. + * estimatorParamMaps + * @throws IllegalArgumentException if subModels are not available. To retrieve subModels, + * make sure to set collectSubModels to true before fitting. */ @Since("2.3.0") def subModels: Array[Model[_]] = { - require(_subModels.isDefined, "submodels not available, set collectSubModels param before " + - "fitting will address this issue.") + require(_subModels.isDefined, "subModels not available, To retrieve subModels, make sure " + + "to set collectSubModels to true before fitting.") _subModels.get } @@ -333,21 +332,26 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { * Writer for TrainValidationSplitModel. * @param instance TrainValidationSplitModel instance used to construct the writer * - * Options: - * TrainValidationSplitModel support an option "persistSubModels", available value is - * "true" or "false". If you set collectSubModels param before fitting, and then you can set - * the option "persistSubModels" to be "true" and the submodels will be persisted. - * The default value of "persistSubModels" will be "true", if you set collectSubModels - * param before fitting, but if you do not set collectSubModels param before fitting, setting - * "persistSubModels" will cause exception. + * TrainValidationSplitModel supports an option "persistSubModels", with possible values + * "true" or "false". If you set the collectSubModels Param before fitting, then you can + * set "persistSubModels" to "true" in order to persist the subModels. By default, + * "persistSubModels" will be "true" when subModels are available and "false" otherwise. + * If subModels are not available, then setting "persistSubModels" to "true" will cause + * an exception. */ - class TrainValidationSplitModelWriter(instance: TrainValidationSplitModel) extends MLWriter { + final class TrainValidationSplitModelWriter private[tuning] ( + instance: TrainValidationSplitModel) extends MLWriter { ValidatorParams.validateParams(instance) override protected def saveImpl(path: String): Unit = { - val persistSubModels = optionMap.getOrElse("persistsubmodels", - if (instance.hasSubModels) "true" else "false").toBoolean + val persistSubModelsParam = optionMap.getOrElse("persistsubmodels", + if (instance.hasSubModels) "true" else "false") + + require(Array("true", "false").contains(persistSubModelsParam.toLowerCase(Locale.ROOT)), + s"persistSubModels option value ${persistSubModelsParam} is invalid, the possible " + + "values are \"true\" or \"false\"") + val persistSubModels = persistSubModelsParam.toBoolean import org.json4s.JsonDSL._ val extraMetadata = ("validationMetrics" -> instance.validationMetrics.toSeq) ~ diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index edc39781c26c..a61690780096 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -111,12 +111,13 @@ abstract class MLWriter extends BaseReadWrite with Logging { protected def saveImpl(path: String): Unit /** - * Map store extra options for this writer. + * Map to store extra options for this writer. */ protected val optionMap: mutable.Map[String, String] = new mutable.HashMap[String, String]() /** - * `option()` handles extra options. + * Adds an option to the underlying MLWriter. See the documentation for the specific model's + * writer for possible options. The option name (key) is case-insensitive. */ @Since("2.3.0") def option(key: String, value: String): this.type = { From 654e4d580889dcd2fcf7c0bea2060349190faaac Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 7 Nov 2017 16:54:28 +0800 Subject: [PATCH 09/10] fix mima --- project/MimaExcludes.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index dd299e074535..a5f03652321c 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -68,7 +68,11 @@ object MimaExcludes { // [SPARK-14280] Support Scala 2.12 ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transformWith"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transform") + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transform"), + + // [SPARK-21087] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala + ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.CrossValidatorModel$CrossValidatorModelWriter"), + ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter") ) // Exclude rules for 2.2.x From 7e997da44157a9807de9c8fe8e7d2e5b66b6bfb1 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Tue, 14 Nov 2017 13:10:26 +0800 Subject: [PATCH 10/10] address minor issues --- .../apache/spark/ml/tuning/CrossValidator.scala | 5 +++-- .../spark/ml/tuning/TrainValidationSplit.scala | 14 ++++++++------ .../spark/ml/tuning/CrossValidatorSuite.scala | 1 - 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 3db70a9d5877..1682ca91bf83 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -399,9 +399,10 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { val bestModelPath = new Path(path, "bestModel").toString val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray - val shouldPersistSubModels = (metadata.metadata \ "persistSubModels").extract[Boolean] + val persistSubModels = (metadata.metadata \ "persistSubModels") + .extractOrElse[Boolean](false) - val subModels: Option[Array[Array[Model[_]]]] = if (shouldPersistSubModels) { + val subModels: Option[Array[Array[Model[_]]]] = if (persistSubModels) { val subModelsPath = new Path(path, "subModels") val _subModels = Array.fill(numFolds)(Array.fill[Model[_]]( estimatorParamMaps.length)(null)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 7f01c4beb2c6..c73bd1847547 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -202,9 +202,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St } @Since("2.0.0") - override def write: TrainValidationSplit.TrainValidationSplitWriter = { - new TrainValidationSplit.TrainValidationSplitWriter(this) - } + override def write: MLWriter = new TrainValidationSplit.TrainValidationSplitWriter(this) } @Since("2.0.0") @@ -311,7 +309,9 @@ class TrainValidationSplitModel private[ml] ( } @Since("2.0.0") - override def write: MLWriter = new TrainValidationSplitModel.TrainValidationSplitModelWriter(this) + override def write: TrainValidationSplitModel.TrainValidationSplitModelWriter = { + new TrainValidationSplitModel.TrainValidationSplitModelWriter(this) + } } @Since("2.0.0") @@ -339,6 +339,7 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { * If subModels are not available, then setting "persistSubModels" to "true" will cause * an exception. */ + @Since("2.3.0") final class TrainValidationSplitModelWriter private[tuning] ( instance: TrainValidationSplitModel) extends MLWriter { @@ -385,9 +386,10 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { val bestModelPath = new Path(path, "bestModel").toString val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray - val shouldPersistSubModels = (metadata.metadata \ "persistSubModels").extract[Boolean] + val persistSubModels = (metadata.metadata \ "persistSubModels") + .extractOrElse[Boolean](false) - val subModels: Option[Array[Model[_]]] = if (shouldPersistSubModels) { + val subModels: Option[Array[Model[_]]] = if (persistSubModels) { val subModelsPath = new Path(path, "subModels") val _subModels = Array.fill[Model[_]](estimatorParamMaps.length)(null) for (paramIndex <- 0 until estimatorParamMaps.length) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 24b608eafeca..d8a7ee6f5693 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -200,7 +200,6 @@ class CrossValidatorSuite val eval = new BinaryClassificationEvaluator val numFolds = 3 val subPath = new File(tempDir, "testCrossValidatorSubModels") - val persistSubModelsPath = new File(subPath, "subModels").toString val cv = new CrossValidator() .setEstimator(lr)