Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -961,51 +961,48 @@ class LinearRegressionSummary private[regression] (
private val privateModel: LinearRegressionModel,
private val diagInvAtWA: Array[Double]) extends Serializable {

// Regression metrics computed over (prediction, label) pairs taken from `predictions`.
// The label column is cast to Double before the rows are converted to tuples, and no
// weight column is selected in this definition.
// The second constructor argument is the negation of the model's fitIntercept setting.
// NOTE(review): a second `metrics` definition that also selects a weight column follows
// this one; this looks like the pre-change side of a diff hunk - confirm which one is live.
@transient private val metrics = new RegressionMetrics(
predictions
.select(col(predictionCol), col(labelCol).cast(DoubleType))
.rdd
.map { case Row(pred: Double, label: Double) => (pred, label) },
!privateModel.getFitIntercept)
// Regression metrics computed over (prediction, label, weight) rows of `predictions`.
@transient private val metrics = {
  // Use the model's weight column only when the param is defined and names a column;
  // otherwise every row gets a constant weight of 1.0.
  val hasWeightCol =
    privateModel.isDefined(privateModel.weightCol) && privateModel.getWeightCol.nonEmpty
  val weight =
    if (hasWeightCol) col(privateModel.getWeightCol).cast(DoubleType) else lit(1.0)

  // Label is cast to Double so that every selected column can be matched as a Double below.
  val predictionLabelWeight = predictions
    .select(col(predictionCol), col(labelCol).cast(DoubleType), weight)
    .rdd
    .map { case Row(p: Double, y: Double, w: Double) => (p, y, w) }

  // The second argument is the negation of the model's fitIntercept setting.
  new RegressionMetrics(predictionLabelWeight, !privateModel.getFitIntercept)
}

/**
* Returns the explained variance regression score.
* explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
* Reference: <a href="http://en.wikipedia.org/wiki/Explained_variation">
* Wikipedia explained variation</a>
*
* @note The value is read from `metrics`, which is built from (prediction, label, weight)
* rows: the weight comes from the model's weight column when one is set and non-empty,
* and is 1.0 for every row otherwise.
*/
@Since("1.5.0")
val explainedVariance: Double = metrics.explainedVariance

/**
* Returns the mean absolute error, which is a risk function corresponding to the
* expected value of the absolute error loss or l1-norm loss.
*
* @note The value is read from `metrics`, which is built from (prediction, label, weight)
* rows: the weight comes from the model's weight column when one is set and non-empty,
* and is 1.0 for every row otherwise.
*/
@Since("1.5.0")
val meanAbsoluteError: Double = metrics.meanAbsoluteError

/**
* Returns the mean squared error, which is a risk function corresponding to the
* expected value of the squared error loss or quadratic loss.
*
* @note The value is read from `metrics`, which is built from (prediction, label, weight)
* rows: the weight comes from the model's weight column when one is set and non-empty,
* and is 1.0 for every row otherwise.
*/
@Since("1.5.0")
val meanSquaredError: Double = metrics.meanSquaredError

/**
* Returns the root mean squared error, which is defined as the square root of
* the mean squared error.
*
* @note The value is read from `metrics`, which is built from (prediction, label, weight)
* rows: the weight comes from the model's weight column when one is set and non-empty,
* and is 1.0 for every row otherwise.
*/
@Since("1.5.0")
val rootMeanSquaredError: Double = metrics.rootMeanSquaredError
Expand All @@ -1014,9 +1011,6 @@ class LinearRegressionSummary private[regression] (
* Returns R^2^, the coefficient of determination.
* Reference: <a href="http://en.wikipedia.org/wiki/Coefficient_of_determination">
* Wikipedia coefficient of determination</a>
*
* @note The value is read from `metrics`, which is built from (prediction, label, weight)
* rows; the weight is 1.0 for every row unless the model has a non-empty weight column set.
*/
@Since("1.5.0")
val r2: Double = metrics.r2
Expand All @@ -1025,9 +1019,6 @@ class LinearRegressionSummary private[regression] (
* Returns Adjusted R^2^, the adjusted coefficient of determination.
* Reference: <a href="https://en.wikipedia.org/wiki/Coefficient_of_determination#Adjusted_R2">
* Wikipedia coefficient of determination</a>
*
* @note This ignores instance weights (setting all to 1.0) from `LinearRegression.weightCol`.
* This will change in later Spark versions.
*/
@Since("2.3.0")
val r2adj: Double = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.apache.spark.ml.util._
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.LinearDataGenerator
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.lit


class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTest {
Expand Down Expand Up @@ -899,6 +900,46 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe
}
}

test("linear regression model training summary with weighted samples") {
  // Every row receives the same constant weight `w`, and each training-summary metric of the
  // weighted fit is required to match the unweighted fit within a relative tolerance of 1e-6.
  // NOTE(review): a constant weight cannot tell a weighted metric from an unweighted one;
  // consider adding a case with non-uniform weights.
  Seq("auto", "l-bfgs", "normal").foreach { solver =>
    val trainer1 = new LinearRegression().setSolver(solver)
    val trainer2 = new LinearRegression().setSolver(solver).setWeightCol("weight")

    // The unweighted baseline does not depend on `w`, so fit it once per solver rather than
    // once per weight value.
    val summary1 = trainer1.fit(datasetWithDenseFeature).summary

    Seq(0.25, 1.0, 10.0, 50.00).foreach { w =>
      val model2 = trainer2.fit(datasetWithDenseFeature.withColumn("weight", lit(w)))
      val summary2 = model2.summary
      assert(summary1.explainedVariance ~== summary2.explainedVariance relTol 1e-6)
      assert(summary1.meanAbsoluteError ~== summary2.meanAbsoluteError relTol 1e-6)
      assert(summary1.meanSquaredError ~== summary2.meanSquaredError relTol 1e-6)
      assert(summary1.rootMeanSquaredError ~== summary2.rootMeanSquaredError relTol 1e-6)
      assert(summary1.r2 ~== summary2.r2 relTol 1e-6)
      assert(summary1.r2adj ~== summary2.r2adj relTol 1e-6)
    }
  }
}

test("linear regression model testset evaluation summary with weighted samples") {
  // Every row receives the same constant weight `w`, and each metric returned by `evaluate`
  // for the weighted model is required to match the unweighted model within a relative
  // tolerance of 1e-6.
  // NOTE(review): a constant weight cannot tell a weighted metric from an unweighted one;
  // consider adding a case with non-uniform weights.
  Seq("auto", "l-bfgs", "normal").foreach { solver =>
    val trainer1 = new LinearRegression().setSolver(solver)
    val trainer2 = new LinearRegression().setSolver(solver).setWeightCol("weight")

    // The unweighted fit and its evaluation do not depend on `w`, so compute them once per
    // solver rather than once per weight value.
    val testSummary1 = trainer1.fit(datasetWithDenseFeature).evaluate(datasetWithDenseFeature)

    Seq(0.25, 1.0, 10.0, 50.00).foreach { w =>
      // The same weighted DataFrame is used for both fitting and evaluation.
      val weightedData = datasetWithDenseFeature.withColumn("weight", lit(w))
      val testSummary2 = trainer2.fit(weightedData).evaluate(weightedData)
      assert(testSummary1.explainedVariance ~== testSummary2.explainedVariance relTol 1e-6)
      assert(testSummary1.meanAbsoluteError ~== testSummary2.meanAbsoluteError relTol 1e-6)
      assert(testSummary1.meanSquaredError ~== testSummary2.meanSquaredError relTol 1e-6)
      assert(testSummary1.rootMeanSquaredError ~==
        testSummary2.rootMeanSquaredError relTol 1e-6)
      assert(testSummary1.r2 ~== testSummary2.r2 relTol 1e-6)
      assert(testSummary1.r2adj ~== testSummary2.r2adj relTol 1e-6)
    }
  }
}

test("linear regression with weighted samples") {
val sqlContext = spark.sqlContext
import sqlContext.implicits._
Expand Down