diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index febeba7e13fcb..e0b128e369816 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.feature.{Instance, LabeledPoint}
+import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -71,7 +72,7 @@ private[ml] trait PredictorParams extends Params
     val w = this match {
       case p: HasWeightCol =>
         if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
-          col($(p.weightCol)).cast(DoubleType)
+          checkNonNegativeWeight(col($(p.weightCol)).cast(DoubleType))
         } else {
           lit(1.0)
         }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 5459a0fab9135..e65295dbdaf55 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -22,6 +22,7 @@ import org.json4s.DefaultFormats
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.PredictorParams
+import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.HasWeightCol
@@ -179,7 +180,7 @@ class NaiveBayes @Since("1.5.0") (
     }
 
     val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
-      col($(weightCol)).cast(DoubleType)
+      checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
     } else {
       lit(1.0)
     }
@@ -259,7 +260,7 @@ class NaiveBayes @Since("1.5.0") (
     import spark.implicits._
 
     val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
-      col($(weightCol)).cast(DoubleType)
+      checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
     } else {
       lit(1.0)
     }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 6c7112b80569f..b09f11dcfe156 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -280,7 +281,7 @@ class BisectingKMeans @Since("2.0.0") (
 
     val handlePersistence = dataset.storageLevel == StorageLevel.NONE
     val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
-      col($(weightCol)).cast(DoubleType)
+      checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
     } else {
       lit(1.0)
     }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 6d4137b638dcc..18fd220b4ca9c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.annotation.Since
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.impl.Utils.{unpackUpperTriangular, EPSILON}
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param._
@@ -417,7 +418,7 @@ class GaussianMixture @Since("2.0.0") (
     instr.logNumFeatures(numFeatures)
 
     val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
-      col($(weightCol)).cast(DoubleType)
+      checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
     } else {
       lit(1.0)
     }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index a42c920e24987..806015b633c23 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{Estimator, Model, PipelineStage}
+import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -336,7 +337,7 @@ class KMeans @Since("1.5.0") (
 
     val handlePersistence = dataset.storageLevel == StorageLevel.NONE
     val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
-      col($(weightCol)).cast(DoubleType)
+      checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
     } else {
       lit(1.0)
     }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index fac4d92b1810c..52be22f714981 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -18,6 +18,7 @@ package org.apache.spark.ml.evaluation
 
 import org.apache.spark.annotation.Since
+import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -131,7 +132,7 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
         col($(rawPredictionCol)),
         col($(labelCol)).cast(DoubleType),
         if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0)
-        else col($(weightCol)).cast(DoubleType)).rdd.map {
+        else checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))).rdd.map {
         case Row(rawPrediction: Vector, label: Double, weight: Double) =>
           (rawPrediction(1), label, weight)
         case Row(rawPrediction: Double, label: Double, weight: Double) =>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
index 19790fd270619..fa2c25a5912a7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
@@ -18,6 +18,7 @@ package org.apache.spark.ml.evaluation
 
 import org.apache.spark.annotation.Since
+import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol, HasWeightCol}
 import org.apache.spark.ml.util._
@@ -139,7 +140,7 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str
     } else {
       dataset.select(col($(predictionCol)),
         vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata),
-        col(weightColName).cast(DoubleType))
+        checkNonNegativeWeight(col(weightColName).cast(DoubleType)))
     }
 
     val metrics = new ClusteringMetrics(df)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala
index 8bf4ee1ecadfb..a785d063f1476 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala
@@ -300,7 +300,6 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
             (featureSum: DenseVector, squaredNormSum: Double, weightSum: Double),
             (features, squaredNorm, weight)
           ) =>
-          require(weight >= 0.0, s"illegal weight value: $weight. weight must be >= 0.0.")
           BLAS.axpy(weight, features, featureSum)
           (featureSum, squaredNormSum + squaredNorm * weight, weightSum + weight)
       },
@@ -503,7 +502,6 @@ private[evaluation] object CosineSilhouette extends Silhouette {
       seqOp = {
         case ((normalizedFeaturesSum: DenseVector, weightSum: Double),
             (normalizedFeatures, weight)) =>
-          require(weight >= 0.0, s"illegal weight value: $weight. weight must be >= 0.0.")
           BLAS.axpy(weight, normalizedFeatures, normalizedFeaturesSum)
           (normalizedFeaturesSum, weightSum + weight)
       },
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
index ad1b70915e157..3d77792c4fc88 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -18,6 +18,7 @@ package org.apache.spark.ml.evaluation
 
 import org.apache.spark.annotation.Since
+import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -186,7 +187,7 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
     SchemaUtils.checkNumericType(schema, $(labelCol))
 
     val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
-      col($(weightCol)).cast(DoubleType)
+      checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
     } else {
       lit(1.0)
     }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index aca017762deca..f0b7c345c3285 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -18,6 +18,7 @@ package org.apache.spark.ml.evaluation
 
 import org.apache.spark.annotation.Since
+import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
@@ -122,7 +123,8 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
     val predictionAndLabelsWithWeights = dataset
       .select(col($(predictionCol)).cast(DoubleType),
         col($(labelCol)).cast(DoubleType),
-        if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)))
+        if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0)
+        else checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)))
       .rdd
       .map { case Row(prediction: Double, label: Double, weight: Double) =>
         (prediction, label, weight) }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/functions.scala b/mllib/src/main/scala/org/apache/spark/ml/functions.scala
index 0f03231079866..a0b6d11a46be9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/functions.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/functions.scala
@@ -71,4 +71,10 @@ object functions {
       )
     }
   }
+
+  private[ml] def checkNonNegativeWeight = udf {
+    value: Double =>
+      require(value >= 0, s"illegal weight value: $value. weight must be >= 0.0.")
+      value
+  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index fa41a98749f32..0ee895a95a288 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.feature.{Instance, OffsetInstance}
+import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
 import org.apache.spark.ml.optim._
 import org.apache.spark.ml.param._
@@ -399,7 +400,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
       "GeneralizedLinearRegression was given data with 0 features, and with Param fitIntercept " +
         "set to false. To fit a model with 0 features, fitIntercept must be set to true."
     )
-    val w = if (!hasWeightCol) lit(1.0) else col($(weightCol))
+    val w = if (!hasWeightCol) lit(1.0) else checkNonNegativeWeight(col($(weightCol)))
     val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType)
 
     val model = if (familyAndLink.family == Gaussian && familyAndLink.link == Identity) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index fe4de57de60f2..ec2640e9ef225 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -87,11 +88,11 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
     } else {
       col($(featuresCol))
     }
-    val w = if (hasWeightCol) col($(weightCol)).cast(DoubleType) else lit(1.0)
+    val w =
+      if (hasWeightCol) checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) else lit(1.0)
     dataset.select(col($(labelCol)).cast(DoubleType), f, w).rdd.map {
-      case Row(label: Double, feature: Double, weight: Double) =>
-        (label, feature, weight)
+      case Row(label: Double, feature: Double, weight: Double) => (label, feature, weight)
     }
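
Note (not part of the patch): every call site above follows the same pattern — cast the user-supplied weight column to DoubleType, then wrap it in the new `checkNonNegativeWeight` UDF so a negative weight fails fast during the scan, replacing the per-algorithm `require` checks that were removed from ClusteringMetrics.scala. Zero weights remain legal (`value >= 0`). Below is a minimal, self-contained sketch of the behavior; the object name, SparkSession setup, DataFrame, and column names are illustrative and not taken from the patch.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.DoubleType

object NonNegativeWeightDemo {
  // Same shape as the helper added to ml/functions.scala above:
  // a pass-through UDF that rejects negative weights.
  private val checkNonNegativeWeight = udf { value: Double =>
    require(value >= 0, s"illegal weight value: $value. weight must be >= 0.0.")
    value
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("weight-check").getOrCreate()
    import spark.implicits._

    val df = Seq((1.0, 0.5), (0.0, -0.1)).toDF("label", "weight")

    // Mirrors the call sites in the patch: cast first, then validate.
    val w = checkNonNegativeWeight(col("weight").cast(DoubleType))

    // The second row trips the require(), so this action fails with
    // "requirement failed: illegal weight value: -0.1. weight must be >= 0.0."
    // (surfaced inside a SparkException thrown by the failed task).
    df.select(col("label"), w.as("weight")).show()

    spark.stop()
  }
}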