From a7667e72d78f679b9693e22742e8a624b6348fd2 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 25 Jul 2017 14:41:17 -0700 Subject: [PATCH] memory optimization --- .../scala/org/apache/spark/ml/tuning/CrossValidator.scala | 6 +++--- .../org/apache/spark/ml/tuning/TrainValidationSplit.scala | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 2012d6ca8b5e..fc2c959d0855 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -112,16 +112,16 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val validationDataset = sparkSession.createDataFrame(validation, schema).cache() // multi-model training logDebug(s"Train split $splitIndex with multiple sets of parameters.") - val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]] - trainingDataset.unpersist() var i = 0 while (i < numModels) { + val model = est.fit(trainingDataset, epm(i)).asInstanceOf[Model[_]] // TODO: duplicate evaluator to take extra params from input - val metric = eval.evaluate(models(i).transform(validationDataset, epm(i))) + val metric = eval.evaluate(model.transform(validationDataset, epm(i))) logDebug(s"Got metric $metric for model trained with ${epm(i)}.") metrics(i) += metric i += 1 } + trainingDataset.unpersist() validationDataset.unpersist() } f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index db7c9d13d301..8e4db62824e2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -108,16 +108,16 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St // multi-model training logDebug(s"Train split with multiple sets of parameters.") - val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]] - trainingDataset.unpersist() var i = 0 while (i < numModels) { + val model = est.fit(trainingDataset, epm(i)).asInstanceOf[Model[_]] // TODO: duplicate evaluator to take extra params from input - val metric = eval.evaluate(models(i).transform(validationDataset, epm(i))) + val metric = eval.evaluate(model.transform(validationDataset, epm(i))) logDebug(s"Got metric $metric for model trained with ${epm(i)}.") metrics(i) += metric i += 1 } + trainingDataset.unpersist() validationDataset.unpersist() logInfo(s"Train validation split metrics: ${metrics.toSeq}")