diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 2012d6ca8b5e..43608b3936bc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -91,6 +91,15 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
   @Since("2.0.0")
   def setSeed(value: Long): this.type = set(seed, value)
 
+  /**
+   * Optional parameter. If set, every model trained during cross validation is saved
+   * under the specified directory path. By default the models are not preserved.
+   *
+   * @group expertSetParam
+   */
+  @Since("2.3.0")
+  def setModelPreservePath(value: String): this.type = set(modelPreservePath, value)
+
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): CrossValidatorModel = {
     val schema = dataset.schema
@@ -113,15 +122,28 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
       // multi-model training
       logDebug(s"Train split $splitIndex with multiple sets of parameters.")
       val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
-      trainingDataset.unpersist()
+
       var i = 0
       while (i < numModels) {
         // TODO: duplicate evaluator to take extra params from input
         val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
         logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
+        if (isDefined(modelPreservePath)) {
+          models(i) match {
+            case w: MLWritable =>
+              // e.g. maxIter-5-regParam-0.001-split0-0.859
+              val fileName = epm(i).toSeq.map(p => p.param.name + "-" + p.value).sorted
+                .mkString("-") + s"-split$splitIndex-${math.rint(metric * 1000) / 1000}"
+              w.save(new Path($(modelPreservePath), fileName).toString)
+            case _ =>
+              // third-party models that do not support ML persistence
+              logWarning(models(i).uid + " does not implement MLWritable; skipping model save.")
+          }
+        }
         metrics(i) += metric
         i += 1
       }
+      trainingDataset.unpersist()
       validationDataset.unpersist()
     }
     f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1)
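A usage illustration (not part of the patch): a minimal sketch of how the new setter would be called once the change above is applied. The names trainingData and /tmp/cv-models are placeholders, and the sub-model directory naming follows the maxIter-5-regParam-0.001-split0-0.859 pattern shown in the diff.

// Sketch only; assumes the modelPreservePath patch is applied and trainingData already exists.
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val lr = new LogisticRegression()
val grid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.001, 0.1))
  .addGrid(lr.maxIter, Array(5, 10))
  .build()

val cv = new CrossValidator()
  .setEstimator(lr)
  .setEstimatorParamMaps(grid)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setNumFolds(3)
  .setModelPreservePath("/tmp/cv-models")  // hypothetical directory; one sub-directory per fitted model

val cvModel = cv.fit(trainingData)  // trainingData: a DataFrame with "label" and "features" columns

// A preserved sub-model can later be reloaded by its directory name, e.g.
// LogisticRegressionModel.load("/tmp/cv-models/maxIter-5-regParam-0.001-split0-0.859")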
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index db7c9d13d301..3f9c2c1e5ebe 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -87,6 +87,15 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String)
   @Since("2.0.0")
   def setSeed(value: Long): this.type = set(seed, value)
 
+  /**
+   * Optional parameter. If set, every model fitted during training is saved under the
+   * specified directory path. By default the models are not saved.
+   *
+   * @group expertSetParam
+   */
+  @Since("2.3.0")
+  def setModelPreservePath(value: String): this.type = set(modelPreservePath, value)
+
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): TrainValidationSplitModel = {
     val schema = dataset.schema
@@ -109,15 +118,27 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String)
     // multi-model training
     logDebug(s"Train split with multiple sets of parameters.")
     val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
-    trainingDataset.unpersist()
+
     var i = 0
     while (i < numModels) {
       // TODO: duplicate evaluator to take extra params from input
       val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
       logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
+      if (isDefined(modelPreservePath)) {
+        models(i) match {
+          case w: MLWritable =>
+            // e.g. maxIter-5-regParam-0.001-0.859
+            val fileName = epm(i).toSeq.map(p => p.param.name + "-" + p.value).sorted
+              .mkString("-") + s"-${math.rint(metric * 1000) / 1000}"
+            w.save(new Path($(modelPreservePath), fileName).toString)
+          case _ =>
+            logWarning(models(i).uid + " does not implement MLWritable; skipping model save.")
+        }
+      }
       metrics(i) += metric
       i += 1
     }
+    trainingDataset.unpersist()
     validationDataset.unpersist()
 
     logInfo(s"Train validation split metrics: ${metrics.toSeq}")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index d55eb14d0345..22cf8cc25bd1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -22,6 +22,7 @@ import org.json4s.{DefaultFormats, _}
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.SparkContext
+import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
@@ -67,6 +68,21 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
   /** @group getParam */
   def getEvaluator: Evaluator = $(evaluator)
+
+  /**
+   * Optional parameter. If set, every model trained during the tuning grid search is saved
+   * under the specified directory path. By default the models are not preserved.
+   *
+   * @group expertParam
+   */
+  val modelPreservePath: Param[String] = new Param(this, "modelPreservePath",
+    "Optional parameter. If set, all the models fitted during tuning will be" +
+      " saved under the given path")
+
+  /** @group expertGetParam */
+  @Since("2.3.0")
+  def getModelPreservePath: String = $(modelPreservePath)
+
 
   protected def transformSchemaImpl(schema: StructType): StructType = {
     require($(estimatorParamMaps).nonEmpty, s"Validator requires non-empty estimatorParamMaps")
     val firstEstimatorParamMap = $(estimatorParamMaps).head
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 2b4e6b53e4f8..64e4b0ea467e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
 import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
 import org.apache.spark.ml.feature.HashingTF
-import org.apache.spark.ml.linalg.{DenseMatrix, Vectors}
+import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.{ParamMap, ParamPair}
 import org.apache.spark.ml.param.shared.HasInputCol
 import org.apache.spark.ml.regression.LinearRegression
@@ -31,6 +31,7 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
 
 class CrossValidatorSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -56,6 +57,7 @@ class CrossValidatorSuite
       .setEstimatorParamMaps(lrParamMaps)
       .setEvaluator(eval)
       .setNumFolds(3)
+    assert(!cv.isDefined(cv.modelPreservePath))
     val cvModel = cv.fit(dataset)
 
     MLTestingUtils.checkCopyAndUids(cv, cvModel)
@@ -242,6 +244,29 @@ class CrossValidatorSuite
     }
   }
 
+  test("cross validation with model path to save trained models") {
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+    val lr = new LogisticRegression
+    val lrParamMaps = new ParamGridBuilder()
+      .addGrid(lr.regParam, Array(0.001, 1000.0))
+      .addGrid(lr.maxIter, Array(0, 5))
+      .build()
+    val eval = new BinaryClassificationEvaluator
+    val cv = new CrossValidator()
+      .setEstimator(lr)
+      .setEstimatorParamMaps(lrParamMaps)
+      .setEvaluator(eval)
+      .setNumFolds(3)
+      .setModelPreservePath(path)
+    try {
+      cv.fit(dataset)
+      assert(tempDir.list().length === 3 * 2 * 2)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+
   test("read/write: CrossValidatorModel") {
     val lr = new LogisticRegression()
       .setThreshold(0.6)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index a34f930aa11c..e7b754ff8d77 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
 
 class TrainValidationSplitSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -53,6 +54,7 @@ class TrainValidationSplitSuite
       .setSeed(42L)
     val tvsModel = tvs.fit(dataset)
     val parent = tvsModel.bestModel.parent.asInstanceOf[LogisticRegression]
+    assert(!tvs.isDefined(tvs.modelPreservePath))
     assert(tvs.getTrainRatio === 0.5)
     assert(parent.getRegParam === 0.001)
     assert(parent.getMaxIter === 10)
@@ -117,6 +119,32 @@ class TrainValidationSplitSuite
     }
   }
 
+  test("train validation with modelPath to save trained models") {
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    val dataset = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2).toDF()
+    val lr = new LogisticRegression
+    val lrParamMaps = new ParamGridBuilder()
+      .addGrid(lr.regParam, Array(0.001, 1000.0))
+      .addGrid(lr.maxIter, Array(0, 10))
+      .build()
+    val eval = new BinaryClassificationEvaluator
+    val tvs = new TrainValidationSplit()
+      .setEstimator(lr)
+      .setEstimatorParamMaps(lrParamMaps)
+      .setEvaluator(eval)
+      .setTrainRatio(0.5)
+      .setSeed(42L)
+      .setModelPreservePath(path)
+    try {
+      tvs.fit(dataset)
+      assert(tempDir.list().length === 2 * 2)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+
   test("read/write: TrainValidationSplit") {
     val lr = new LogisticRegression().setMaxIter(3)
     val evaluator = new BinaryClassificationEvaluator()
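A follow-up illustration (not part of the patch): because the saved directory names encode the parameter values and the rounded metric, the preserved models can be listed and reloaded afterwards. A sketch, assuming CrossValidator saved LogisticRegression sub-models into a hypothetical directory /tmp/cv-models.

// Sketch only; assumes the modelPreservePath patch is applied and models were already saved.
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.ml.classification.LogisticRegressionModel
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()
val fs = FileSystem.get(spark.sparkContext.hadoopConfiguration)

// Directory names look like maxIter-5-regParam-0.001-split0-0.859; the last "-"-separated
// token is the rounded validation metric, so it can be parsed back out for ranking.
val saved = fs.listStatus(new Path("/tmp/cv-models")).map(_.getPath)
val ranked = saved.map(p => (p.getName.split("-").last.toDouble, p)).sortBy(-_._1)

// Reload the best-scoring preserved sub-model (the estimator here was LogisticRegression).
val best = LogisticRegressionModel.load(ranked.head._2.toString)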