From 62ab5d00d5df258aa8fdce84c9c487b953d4aa0b Mon Sep 17 00:00:00 2001 From: Aman Omer Date: Sat, 9 Nov 2019 23:14:38 +0530 Subject: [PATCH] Persist the input RDD in CrossValidator before k-fold splitting --- .../scala/org/apache/spark/ml/tuning/CrossValidator.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index e60a14f976a5..6c00b569aa06 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -141,8 +141,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) Some(Array.fill($(numFolds))(Array.fill[Model[_]](epm.length)(null))) } else None + val inputRDD = dataset.toDF.rdd + inputRDD.persist() // Compute metrics for each model over each split - val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) + val splits = MLUtils.kFold(inputRDD, $(numFolds), $(seed)) val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) => val trainingDataset = sparkSession.createDataFrame(training, schema).cache() val validationDataset = sparkSession.createDataFrame(validation, schema).cache()