diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 64ecc1ebda589..214973f7396dc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -341,7 +341,7 @@ class GaussianMixture @Since("2.0.0") ( val sc = dataset.sparkSession.sparkContext val numClusters = $(k) - val instances: RDD[Vector] = dataset + val instances = dataset .select(DatasetUtils.columnToVector(dataset, getFeaturesCol)).rdd.map { case Row(features: Vector) => features }.cache() @@ -416,6 +416,7 @@ class GaussianMixture @Since("2.0.0") ( iter += 1 } + instances.unpersist(false) val gaussianDists = gaussians.map { case (mean, covVec) => val cov = GaussianMixture.unpackUpperTriangularMatrix(numFeatures, covVec.values) new MultivariateGaussian(mean, cov)