diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index bcf9b7c0426cd..8b6ede3bb362c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -961,21 +961,27 @@ class LinearRegressionSummary private[regression] ( private val privateModel: LinearRegressionModel, private val diagInvAtWA: Array[Double]) extends Serializable { - @transient private val metrics = new RegressionMetrics( - predictions - .select(col(predictionCol), col(labelCol).cast(DoubleType)) - .rdd - .map { case Row(pred: Double, label: Double) => (pred, label) }, - !privateModel.getFitIntercept) + @transient private val metrics = { + val weightCol = + if (!privateModel.isDefined(privateModel.weightCol) || privateModel.getWeightCol.isEmpty) { + lit(1.0) + } else { + col(privateModel.getWeightCol).cast(DoubleType) + } + + new RegressionMetrics( + predictions + .select(col(predictionCol), col(labelCol).cast(DoubleType), weightCol) + .rdd + .map { case Row(pred: Double, label: Double, weight: Double) => (pred, label, weight) }, + !privateModel.getFitIntercept) + } /** * Returns the explained variance regression score. * explainedVariance = 1 - variance(y - \hat{y}) / variance(y) * Reference: * Wikipedia explain variation - * - * @note This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`. - * This will change in later Spark versions. */ @Since("1.5.0") val explainedVariance: Double = metrics.explainedVariance @@ -983,9 +989,6 @@ class LinearRegressionSummary private[regression] ( /** * Returns the mean absolute error, which is a risk function corresponding to the * expected value of the absolute error loss or l1-norm loss. - * - * @note This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`. - * This will change in later Spark versions. */ @Since("1.5.0") val meanAbsoluteError: Double = metrics.meanAbsoluteError @@ -993,9 +996,6 @@ class LinearRegressionSummary private[regression] ( /** * Returns the mean squared error, which is a risk function corresponding to the * expected value of the squared error loss or quadratic loss. - * - * @note This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`. - * This will change in later Spark versions. */ @Since("1.5.0") val meanSquaredError: Double = metrics.meanSquaredError @@ -1003,9 +1003,6 @@ class LinearRegressionSummary private[regression] ( /** * Returns the root mean squared error, which is defined as the square root of * the mean squared error. - * - * @note This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`. - * This will change in later Spark versions. */ @Since("1.5.0") val rootMeanSquaredError: Double = metrics.rootMeanSquaredError @@ -1014,9 +1011,6 @@ class LinearRegressionSummary private[regression] ( * Returns R^2^, the coefficient of determination. * Reference: * Wikipedia coefficient of determination - * - * @note This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`. - * This will change in later Spark versions. */ @Since("1.5.0") val r2: Double = metrics.r2 @@ -1025,9 +1019,6 @@ class LinearRegressionSummary private[regression] ( * Returns Adjusted R^2^, the adjusted coefficient of determination. * Reference: * Wikipedia coefficient of determination - * - * @note This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`. - * This will change in later Spark versions. */ @Since("2.3.0") val r2adj: Double = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index df9a66b49fe48..c4a94ff2d6f44 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.lit class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTest { @@ -899,6 +900,46 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe } } + test("linear regression model training summary with weighted samples") { + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer1 = new LinearRegression().setSolver(solver) + val trainer2 = new LinearRegression().setSolver(solver).setWeightCol("weight") + + Seq(0.25, 1.0, 10.0, 50.00).foreach { w => + val model1 = trainer1.fit(datasetWithDenseFeature) + val model2 = trainer2.fit(datasetWithDenseFeature.withColumn("weight", lit(w))) + assert(model1.summary.explainedVariance ~== model2.summary.explainedVariance relTol 1e-6) + assert(model1.summary.meanAbsoluteError ~== model2.summary.meanAbsoluteError relTol 1e-6) + assert(model1.summary.meanSquaredError ~== model2.summary.meanSquaredError relTol 1e-6) + assert(model1.summary.rootMeanSquaredError ~== + model2.summary.rootMeanSquaredError relTol 1e-6) + assert(model1.summary.r2 ~== model2.summary.r2 relTol 1e-6) + assert(model1.summary.r2adj ~== model2.summary.r2adj relTol 1e-6) + } + } + } + + test("linear regression model testset evaluation summary with weighted samples") { + Seq("auto", "l-bfgs", "normal").foreach { solver => + val trainer1 = new LinearRegression().setSolver(solver) + val trainer2 = new LinearRegression().setSolver(solver).setWeightCol("weight") + + Seq(0.25, 1.0, 10.0, 50.00).foreach { w => + val model1 = trainer1.fit(datasetWithDenseFeature) + val model2 = trainer2.fit(datasetWithDenseFeature.withColumn("weight", lit(w))) + val testSummary1 = model1.evaluate(datasetWithDenseFeature) + val testSummary2 = model2.evaluate(datasetWithDenseFeature.withColumn("weight", lit(w))) + assert(testSummary1.explainedVariance ~== testSummary2.explainedVariance relTol 1e-6) + assert(testSummary1.meanAbsoluteError ~== testSummary2.meanAbsoluteError relTol 1e-6) + assert(testSummary1.meanSquaredError ~== testSummary2.meanSquaredError relTol 1e-6) + assert(testSummary1.rootMeanSquaredError ~== + testSummary2.rootMeanSquaredError relTol 1e-6) + assert(testSummary1.r2 ~== testSummary2.r2 relTol 1e-6) + assert(testSummary1.r2adj ~== testSummary2.r2adj relTol 1e-6) + } + } + } + test("linear regression with weighted samples") { val sqlContext = spark.sqlContext import sqlContext.implicits._