diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index bcf9b7c0426cd..8b6ede3bb362c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -961,21 +961,27 @@ class LinearRegressionSummary private[regression] (
private val privateModel: LinearRegressionModel,
private val diagInvAtWA: Array[Double]) extends Serializable {
- @transient private val metrics = new RegressionMetrics(
- predictions
- .select(col(predictionCol), col(labelCol).cast(DoubleType))
- .rdd
- .map { case Row(pred: Double, label: Double) => (pred, label) },
- !privateModel.getFitIntercept)
+  @transient private val metrics = {
+    // Per-row weight: the model's weight column (cast to double) when that param is set to a
+    // non-empty name; otherwise every row gets the constant weight 1.0.
+    val hasWeightCol =
+      privateModel.isDefined(privateModel.weightCol) && privateModel.getWeightCol.nonEmpty
+    val weightExpr =
+      if (hasWeightCol) col(privateModel.getWeightCol).cast(DoubleType) else lit(1.0)
+    // (prediction, label, weight) triples handed to RegressionMetrics.
+    val triples = predictions
+      .select(col(predictionCol), col(labelCol).cast(DoubleType), weightExpr)
+      .rdd
+      .map { case Row(pred: Double, label: Double, weight: Double) => (pred, label, weight) }
+
+    new RegressionMetrics(triples, !privateModel.getFitIntercept)
+  }
/**
* Returns the explained variance regression score.
* explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
* Reference:
* Wikipedia explain variation
- *
- * @note This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`.
- * This will change in later Spark versions.
*/
@Since("1.5.0")
val explainedVariance: Double = metrics.explainedVariance
@@ -983,9 +989,6 @@ class LinearRegressionSummary private[regression] (
/**
* Returns the mean absolute error, which is a risk function corresponding to the
* expected value of the absolute error loss or l1-norm loss.
- *
- * @note This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`.
- * This will change in later Spark versions.
*/
@Since("1.5.0")
val meanAbsoluteError: Double = metrics.meanAbsoluteError
@@ -993,9 +996,6 @@ class LinearRegressionSummary private[regression] (
/**
* Returns the mean squared error, which is a risk function corresponding to the
* expected value of the squared error loss or quadratic loss.
- *
- * @note This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`.
- * This will change in later Spark versions.
*/
@Since("1.5.0")
val meanSquaredError: Double = metrics.meanSquaredError
@@ -1003,9 +1003,6 @@ class LinearRegressionSummary private[regression] (
/**
* Returns the root mean squared error, which is defined as the square root of
* the mean squared error.
- *
- * @note This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`.
- * This will change in later Spark versions.
*/
@Since("1.5.0")
val rootMeanSquaredError: Double = metrics.rootMeanSquaredError
@@ -1014,9 +1011,6 @@ class LinearRegressionSummary private[regression] (
* Returns R^2^, the coefficient of determination.
* Reference:
* Wikipedia coefficient of determination
- *
- * @note This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`.
- * This will change in later Spark versions.
*/
@Since("1.5.0")
val r2: Double = metrics.r2
@@ -1025,9 +1019,6 @@ class LinearRegressionSummary private[regression] (
* Returns Adjusted R^2^, the adjusted coefficient of determination.
* Reference:
* Wikipedia coefficient of determination
- *
- * @note This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`.
- * This will change in later Spark versions.
*/
@Since("2.3.0")
val r2adj: Double = {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index df9a66b49fe48..c4a94ff2d6f44 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.ml.util._
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions.lit
class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTest {
@@ -899,6 +900,46 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe
}
}
+  test("linear regression model training summary with weighted samples") {
+    // Giving every row the same weight w must leave each training-summary metric unchanged.
+    Seq("auto", "l-bfgs", "normal").foreach { solver =>
+      val trainer2 = new LinearRegression().setSolver(solver).setWeightCol("weight")
+      // The unweighted baseline does not depend on w, so fit it once per solver.
+      val model1 = new LinearRegression().setSolver(solver).fit(datasetWithDenseFeature)
+      Seq(0.25, 1.0, 10.0, 50.00).foreach { w =>
+        val model2 = trainer2.fit(datasetWithDenseFeature.withColumn("weight", lit(w)))
+        assert(model1.summary.explainedVariance ~== model2.summary.explainedVariance relTol 1e-6)
+        assert(model1.summary.meanAbsoluteError ~== model2.summary.meanAbsoluteError relTol 1e-6)
+        assert(model1.summary.meanSquaredError ~== model2.summary.meanSquaredError relTol 1e-6)
+        assert(model1.summary.rootMeanSquaredError ~==
+          model2.summary.rootMeanSquaredError relTol 1e-6)
+        assert(model1.summary.r2 ~== model2.summary.r2 relTol 1e-6)
+        assert(model1.summary.r2adj ~== model2.summary.r2adj relTol 1e-6)
+      }
+    }
+  }
+
+  test("linear regression model testset evaluation summary with weighted samples") {
+    // Giving every row the same weight w must leave each evaluate() summary metric unchanged.
+    Seq("auto", "l-bfgs", "normal").foreach { solver =>
+      val trainer2 = new LinearRegression().setSolver(solver).setWeightCol("weight")
+      // The unweighted model and its evaluation do not depend on w; compute them once per solver.
+      val model1 = new LinearRegression().setSolver(solver).fit(datasetWithDenseFeature)
+      val testSummary1 = model1.evaluate(datasetWithDenseFeature)
+      Seq(0.25, 1.0, 10.0, 50.00).foreach { w =>
+        val model2 = trainer2.fit(datasetWithDenseFeature.withColumn("weight", lit(w)))
+        val testSummary2 = model2.evaluate(datasetWithDenseFeature.withColumn("weight", lit(w)))
+        assert(testSummary1.explainedVariance ~== testSummary2.explainedVariance relTol 1e-6)
+        assert(testSummary1.meanAbsoluteError ~== testSummary2.meanAbsoluteError relTol 1e-6)
+        assert(testSummary1.meanSquaredError ~== testSummary2.meanSquaredError relTol 1e-6)
+        assert(testSummary1.rootMeanSquaredError ~==
+          testSummary2.rootMeanSquaredError relTol 1e-6)
+        assert(testSummary1.r2 ~== testSummary2.r2 relTol 1e-6)
+        assert(testSummary1.r2adj ~== testSummary2.r2adj relTol 1e-6)
+      }
+    }
+  }
+
test("linear regression with weighted samples") {
val sqlContext = spark.sqlContext
import sqlContext.implicits._