3 changes: 2 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml

 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.feature.{Instance, LabeledPoint}
+import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -71,7 +72,7 @@ private[ml] trait PredictorParams extends Params
     val w = this match {
       case p: HasWeightCol =>
         if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
-          col($(p.weightCol)).cast(DoubleType)
+          checkNonNegativeWeight(col($(p.weightCol)).cast(DoubleType))
        } else {
          lit(1.0)
        }
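This guard is the pattern the PR repeats in every touched file: when a weight column is configured, it is cast to DoubleType and wrapped in the new checkNonNegativeWeight UDF, so each row's weight is validated as it is read; otherwise a constant weight of 1.0 is used. A minimal standalone sketch of the same pattern (the column name is hypothetical, and since checkNonNegativeWeight is private[ml] this only compiles inside the org.apache.spark.ml package):

import org.apache.spark.ml.functions.checkNonNegativeWeight
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DoubleType

// Resolve the weight expression: validate a user-supplied column, or default to 1.0.
val weightColName: Option[String] = Some("weight")  // hypothetical configuration
val w = weightColName match {
  case Some(name) if name.nonEmpty => checkNonNegativeWeight(col(name).cast(DoubleType))
  case _ => lit(1.0)
}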
5 changes: 3 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -22,6 +22,7 @@ import org.json4s.DefaultFormats

 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.PredictorParams
+import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.HasWeightCol
@@ -179,7 +180,7 @@ class NaiveBayes @Since("1.5.0") (
     }

     val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
-      col($(weightCol)).cast(DoubleType)
+      checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
     } else {
       lit(1.0)
     }
@@ -259,7 +260,7 @@ class NaiveBayes @Since("1.5.0") (
     import spark.implicits._

     val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
-      col($(weightCol)).cast(DoubleType)
+      checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
     } else {
       lit(1.0)
     }
3 changes: 2 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path

 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -280,7 +281,7 @@ class BisectingKMeans @Since("2.0.0") (

     val handlePersistence = dataset.storageLevel == StorageLevel.NONE
     val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
-      col($(weightCol)).cast(DoubleType)
+      checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
     } else {
       lit(1.0)
     }
3 changes: 2 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.annotation.Since
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.impl.Utils.{unpackUpperTriangular, EPSILON}
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param._
@@ -417,7 +418,7 @@ class GaussianMixture @Since("2.0.0") (
     instr.logNumFeatures(numFeatures)

     val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
-      col($(weightCol)).cast(DoubleType)
+      checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
     } else {
       lit(1.0)
     }
3 changes: 2 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path

 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{Estimator, Model, PipelineStage}
+import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -336,7 +337,7 @@ class KMeans @Since("1.5.0") (

     val handlePersistence = dataset.storageLevel == StorageLevel.NONE
     val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
-      col($(weightCol)).cast(DoubleType)
+      checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
     } else {
       lit(1.0)
     }
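For callers, the visible change is fail-fast behavior: a negative weight now aborts fit with the require message instead of silently distorting the clustering cost. A rough usage sketch, assuming a local SparkSession on a build containing this patch (data and column names invented for illustration):

import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("weight-check").getOrCreate()
import spark.implicits._

val df = Seq(
  (Vectors.dense(0.0, 0.0), 1.0),
  (Vectors.dense(1.0, 1.0), -0.5)   // illegal: negative weight
).toDF("features", "w")

// Fails during fit with "illegal weight value: -0.5. weight must be >= 0.0."
// (thrown by the UDF's require, surfaced inside a SparkException).
new KMeans().setK(2).setWeightCol("w").fit(df)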
3 changes: 2 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.evaluation

 import org.apache.spark.annotation.Since
+import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -131,7 +132,7 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
       col($(rawPredictionCol)),
       col($(labelCol)).cast(DoubleType),
       if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0)
-      else col($(weightCol)).cast(DoubleType)).rdd.map {
+      else checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))).rdd.map {
       case Row(rawPrediction: Vector, label: Double, weight: Double) =>
         (rawPrediction(1), label, weight)
       case Row(rawPrediction: Double, label: Double, weight: Double) =>
3 changes: 2 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.evaluation

 import org.apache.spark.annotation.Since
+import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol, HasWeightCol}
 import org.apache.spark.ml.util._
@@ -139,7 +140,7 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: Str
     } else {
       dataset.select(col($(predictionCol)),
         vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata),
-        col(weightColName).cast(DoubleType))
+        checkNonNegativeWeight(col(weightColName).cast(DoubleType)))
     }

     val metrics = new ClusteringMetrics(df)
2 changes: 0 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringMetrics.scala
@@ -300,7 +300,6 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette {
         (featureSum: DenseVector, squaredNormSum: Double, weightSum: Double),
         (features, squaredNorm, weight)
       ) =>
-        require(weight >= 0.0, s"illegal weight value: $weight. weight must be >= 0.0.")
         BLAS.axpy(weight, features, featureSum)
         (featureSum, squaredNormSum + squaredNorm * weight, weightSum + weight)
       },
@@ -503,7 +502,6 @@ private[evaluation] object CosineSilhouette extends Silhouette {
       seqOp = {
         case ((normalizedFeaturesSum: DenseVector, weightSum: Double),
           (normalizedFeatures, weight)) =>
-          require(weight >= 0.0, s"illegal weight value: $weight. weight must be >= 0.0.")
           BLAS.axpy(weight, normalizedFeatures, normalizedFeaturesSum)
           (normalizedFeaturesSum, weightSum + weight)
         },
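These two per-row require calls become redundant rather than lost: ClusteringEvaluator (above) now wraps the weight column in checkNonNegativeWeight, so an illegal weight is rejected with the same error message before the silhouette aggregation ever runs.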
3 changes: 2 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.evaluation

 import org.apache.spark.annotation.Since
+import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -186,7 +187,7 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
     SchemaUtils.checkNumericType(schema, $(labelCol))

     val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
-      col($(weightCol)).cast(DoubleType)
+      checkNonNegativeWeight(col($(weightCol)).cast(DoubleType))
     } else {
       lit(1.0)
     }
4 changes: 3 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.evaluation

 import org.apache.spark.annotation.Since
+import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
@@ -122,7 +123,8 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui

     val predictionAndLabelsWithWeights = dataset
       .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType),
-        if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)))
+        if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0)
+        else checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)))
       .rdd
       .map { case Row(prediction: Double, label: Double, weight: Double) =>
         (prediction, label, weight) }
6 changes: 6 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/functions.scala
@@ -71,4 +71,10 @@ object functions {
       )
     }
   }
+
+  private[ml] def checkNonNegativeWeight = udf {
+    value: Double =>
+      require(value >= 0, s"illegal weight value: $value. weight must be >= 0.0.")
+      value
+  }
 }
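The new helper is an ordinary Spark SQL UDF: it passes each weight through unchanged and throws on the first negative value, so validation piggybacks on whatever scan already reads the column instead of adding a separate pass over the data. A sketch of an equivalent standalone definition and use (DataFrame and column name hypothetical; the real helper is private[ml]):

import org.apache.spark.sql.functions.{col, udf}

// Identity on the weight value, but require() aborts the job on a negative input.
val checkWeight = udf { value: Double =>
  require(value >= 0, s"illegal weight value: $value. weight must be >= 0.0.")
  value
}

// df.select(checkWeight(col("w").cast("double")))   // validated weight column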
3 changes: 2 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.feature.{Instance, OffsetInstance}
+import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
 import org.apache.spark.ml.optim._
 import org.apache.spark.ml.param._
@@ -399,7 +400,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
       "GeneralizedLinearRegression was given data with 0 features, and with Param fitIntercept " +
       "set to false. To fit a model with 0 features, fitIntercept must be set to true." )

-    val w = if (!hasWeightCol) lit(1.0) else col($(weightCol))
+    val w = if (!hasWeightCol) lit(1.0) else checkNonNegativeWeight(col($(weightCol)))
     val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType)

     val model = if (familyAndLink.family == Gaussian && familyAndLink.link == Identity) {
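One small inconsistency worth noting in review: unlike the other call sites, this one applies checkNonNegativeWeight to col($(weightCol)) without an explicit .cast(DoubleType). Because the UDF's parameter is a Double, the analyzer should insert the same cast implicitly, though an explicit cast would match the surrounding files.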
7 changes: 4 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -22,6 +22,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.functions.checkNonNegativeWeight
 import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -87,11 +88,11 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
     } else {
       col($(featuresCol))
     }
-    val w = if (hasWeightCol) col($(weightCol)).cast(DoubleType) else lit(1.0)
+    val w =
+      if (hasWeightCol) checkNonNegativeWeight(col($(weightCol)).cast(DoubleType)) else lit(1.0)

     dataset.select(col($(labelCol)).cast(DoubleType), f, w).rdd.map {
-      case Row(label: Double, feature: Double, weight: Double) =>
-        (label, feature, weight)
+      case Row(label: Double, feature: Double, weight: Double) => (label, feature, weight)
     }
   }