diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala index 3dea244c77226..3035688709301 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala @@ -37,13 +37,18 @@ class ClusteringMetrics private[spark](dataset: Dataset[_]) { def getDistanceMeasure: String = distanceMeasure - def setDistanceMeasure(value: String) : Unit = distanceMeasure = value + def setDistanceMeasure(value: String) : this.type = { + require(value.equalsIgnoreCase("squaredEuclidean") || + value.equalsIgnoreCase("cosine")) + distanceMeasure = value + this + } /** * Returns the silhouette score */ @Since("3.1.0") - lazy val silhouette: Double = { + def silhouette(): Double = { val columns = dataset.columns.toSeq if (distanceMeasure.equalsIgnoreCase("squaredEuclidean")) { SquaredEuclideanSilhouette.computeSilhouetteScore(