diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 2012d6ca8b5e..fc2c959d0855 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -112,16 +112,16 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val validationDataset = sparkSession.createDataFrame(validation, schema).cache() // multi-model training logDebug(s"Train split $splitIndex with multiple sets of parameters.") - val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]] - trainingDataset.unpersist() var i = 0 while (i < numModels) { + val model = est.fit(trainingDataset, epm(i)).asInstanceOf[Model[_]] // TODO: duplicate evaluator to take extra params from input - val metric = eval.evaluate(models(i).transform(validationDataset, epm(i))) + val metric = eval.evaluate(model.transform(validationDataset, epm(i))) logDebug(s"Got metric $metric for model trained with ${epm(i)}.") metrics(i) += metric i += 1 } + trainingDataset.unpersist() validationDataset.unpersist() } f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index db7c9d13d301..8e4db62824e2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -108,16 +108,16 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St // multi-model training logDebug(s"Train split with multiple sets of parameters.") - val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]] - trainingDataset.unpersist() var i = 0 while (i < numModels) { + val model = est.fit(trainingDataset, epm(i)).asInstanceOf[Model[_]] // TODO: duplicate evaluator to take extra params from input - val metric = eval.evaluate(models(i).transform(validationDataset, epm(i))) + val metric = eval.evaluate(model.transform(validationDataset, epm(i))) logDebug(s"Got metric $metric for model trained with ${epm(i)}.") metrics(i) += metric i += 1 } + trainingDataset.unpersist() validationDataset.unpersist() logInfo(s"Train validation split metrics: ${metrics.toSeq}")