Skip to content

Commit d679843

Browse files
mengxr authored and tdas committed
[SPARK-1327] GLM needs to check addIntercept for intercept and weights
GLM needs to check addIntercept for intercept and weights. The current implementation always uses the first weight as intercept. Added a test for training without adding intercept. JIRA: https://spark-project.atlassian.net/browse/SPARK-1327 Author: Xiangrui Meng <[email protected]> Closes #236 from mengxr/glm and squashes the following commits: bcac1ac [Xiangrui Meng] add two tests to ensure {Lasso, Ridge}.setIntercept will throw an exceptions a104072 [Xiangrui Meng] remove protected to be compatible with 0.9 0e57aa4 [Xiangrui Meng] update Lasso and RidgeRegression to parse the weights correctly from GLM mark createModel protected mark predictPoint protected d7f629f [Xiangrui Meng] fix a bug in GLM when intercept is not used
1 parent 1fa48d9 commit d679843

File tree

7 files changed

+86
-37
lines changed

7 files changed

+86
-37
lines changed

mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -136,25 +136,28 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
136136

137137
// Prepend an extra variable consisting of all 1.0's for the intercept.
138138
val data = if (addIntercept) {
139-
input.map(labeledPoint => (labeledPoint.label, labeledPoint.features.+:(1.0)))
139+
input.map(labeledPoint => (labeledPoint.label, 1.0 +: labeledPoint.features))
140140
} else {
141141
input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
142142
}
143143

144144
val initialWeightsWithIntercept = if (addIntercept) {
145-
initialWeights.+:(1.0)
145+
0.0 +: initialWeights
146146
} else {
147147
initialWeights
148148
}
149149

150-
val weights = optimizer.optimize(data, initialWeightsWithIntercept)
151-
val intercept = weights(0)
152-
val weightsScaled = weights.tail
150+
val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept)
153151

154-
val model = createModel(weightsScaled, intercept)
152+
val (intercept, weights) = if (addIntercept) {
153+
(weightsWithIntercept(0), weightsWithIntercept.tail)
154+
} else {
155+
(0.0, weightsWithIntercept)
156+
}
157+
158+
logInfo("Final weights " + weights.mkString(","))
159+
logInfo("Final intercept " + intercept)
155160

156-
logInfo("Final model weights " + model.weights.mkString(","))
157-
logInfo("Final model intercept " + model.intercept)
158-
model
161+
createModel(weights, intercept)
159162
}
160163
}

mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,10 @@ class LassoModel(
3636
extends GeneralizedLinearModel(weights, intercept)
3737
with RegressionModel with Serializable {
3838

39-
override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
40-
intercept: Double) = {
39+
override def predictPoint(
40+
dataMatrix: DoubleMatrix,
41+
weightMatrix: DoubleMatrix,
42+
intercept: Double): Double = {
4143
dataMatrix.dot(weightMatrix) + intercept
4244
}
4345
}
@@ -66,7 +68,7 @@ class LassoWithSGD private (
6668
.setMiniBatchFraction(miniBatchFraction)
6769

6870
// We don't want to penalize the intercept, so set this to false.
69-
setIntercept(false)
71+
super.setIntercept(false)
7072

7173
var yMean = 0.0
7274
var xColMean: DoubleMatrix = _
@@ -77,10 +79,16 @@ class LassoWithSGD private (
7779
*/
7880
def this() = this(1.0, 100, 1.0, 1.0)
7981

80-
def createModel(weights: Array[Double], intercept: Double) = {
81-
val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*)
82+
override def setIntercept(addIntercept: Boolean): this.type = {
83+
// TODO: Support adding intercept.
84+
if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.")
85+
this
86+
}
87+
88+
override def createModel(weights: Array[Double], intercept: Double) = {
89+
val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*)
8290
val weightsScaled = weightsMat.div(xColSd)
83-
val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0))
91+
val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)
8492

8593
new LassoModel(weightsScaled.data, interceptScaled)
8694
}

mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,14 @@ import org.jblas.DoubleMatrix
3131
* @param intercept Intercept computed for this model.
3232
*/
3333
class LinearRegressionModel(
34-
override val weights: Array[Double],
35-
override val intercept: Double)
36-
extends GeneralizedLinearModel(weights, intercept)
37-
with RegressionModel with Serializable {
38-
39-
override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
40-
intercept: Double) = {
34+
override val weights: Array[Double],
35+
override val intercept: Double)
36+
extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable {
37+
38+
override def predictPoint(
39+
dataMatrix: DoubleMatrix,
40+
weightMatrix: DoubleMatrix,
41+
intercept: Double): Double = {
4142
dataMatrix.dot(weightMatrix) + intercept
4243
}
4344
}
@@ -55,8 +56,7 @@ class LinearRegressionWithSGD private (
5556
var stepSize: Double,
5657
var numIterations: Int,
5758
var miniBatchFraction: Double)
58-
extends GeneralizedLinearAlgorithm[LinearRegressionModel]
59-
with Serializable {
59+
extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable {
6060

6161
val gradient = new LeastSquaresGradient()
6262
val updater = new SimpleUpdater()
@@ -69,7 +69,7 @@ class LinearRegressionWithSGD private (
6969
*/
7070
def this() = this(1.0, 100, 1.0)
7171

72-
def createModel(weights: Array[Double], intercept: Double) = {
72+
override def createModel(weights: Array[Double], intercept: Double) = {
7373
new LinearRegressionModel(weights, intercept)
7474
}
7575
}

mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,10 @@ class RidgeRegressionModel(
3636
extends GeneralizedLinearModel(weights, intercept)
3737
with RegressionModel with Serializable {
3838

39-
override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
40-
intercept: Double) = {
39+
override def predictPoint(
40+
dataMatrix: DoubleMatrix,
41+
weightMatrix: DoubleMatrix,
42+
intercept: Double): Double = {
4143
dataMatrix.dot(weightMatrix) + intercept
4244
}
4345
}
@@ -67,7 +69,7 @@ class RidgeRegressionWithSGD private (
6769
.setMiniBatchFraction(miniBatchFraction)
6870

6971
// We don't want to penalize the intercept in RidgeRegression, so set this to false.
70-
setIntercept(false)
72+
super.setIntercept(false)
7173

7274
var yMean = 0.0
7375
var xColMean: DoubleMatrix = _
@@ -78,8 +80,14 @@ class RidgeRegressionWithSGD private (
7880
*/
7981
def this() = this(1.0, 100, 1.0, 1.0)
8082

81-
def createModel(weights: Array[Double], intercept: Double) = {
82-
val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*)
83+
override def setIntercept(addIntercept: Boolean): this.type = {
84+
// TODO: Support adding intercept.
85+
if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.")
86+
this
87+
}
88+
89+
override def createModel(weights: Array[Double], intercept: Double) = {
90+
val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*)
8391
val weightsScaled = weightsMat.div(xColSd)
8492
val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)
8593

mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,8 @@
1717

1818
package org.apache.spark.mllib.regression
1919

20-
21-
import org.scalatest.BeforeAndAfterAll
2220
import org.scalatest.FunSuite
2321

24-
import org.apache.spark.SparkContext
2522
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
2623

2724
class LassoSuite extends FunSuite with LocalSparkContext {
@@ -104,4 +101,10 @@ class LassoSuite extends FunSuite with LocalSparkContext {
104101
// Test prediction on Array.
105102
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
106103
}
104+
105+
test("do not support intercept") {
106+
intercept[UnsupportedOperationException] {
107+
new LassoWithSGD().setIntercept(true)
108+
}
109+
}
107110
}

mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
package org.apache.spark.mllib.regression
1919

20-
import org.scalatest.BeforeAndAfterAll
2120
import org.scalatest.FunSuite
2221

2322
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
@@ -57,4 +56,29 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
5756
// Test prediction on Array.
5857
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
5958
}
59+
60+
// Test if we can correctly learn Y = 10*X1 + 10*X2
61+
test("linear regression without intercept") {
62+
val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput(
63+
0.0, Array(10.0, 10.0), 100, 42), 2).cache()
64+
val linReg = new LinearRegressionWithSGD().setIntercept(false)
65+
linReg.optimizer.setNumIterations(1000).setStepSize(1.0)
66+
67+
val model = linReg.run(testRDD)
68+
69+
assert(model.intercept === 0.0)
70+
assert(model.weights.length === 2)
71+
assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0)
72+
assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0)
73+
74+
val validationData = LinearDataGenerator.generateLinearInput(
75+
0.0, Array(10.0, 10.0), 100, 17)
76+
val validationRDD = sc.parallelize(validationData, 2).cache()
77+
78+
// Test prediction on RDD.
79+
validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
80+
81+
// Test prediction on Array.
82+
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
83+
}
6084
}

mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,11 @@
1717

1818
package org.apache.spark.mllib.regression
1919

20-
2120
import org.jblas.DoubleMatrix
22-
import org.scalatest.BeforeAndAfterAll
2321
import org.scalatest.FunSuite
2422

2523
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
2624

27-
2825
class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
2926

3027
def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = {
@@ -74,4 +71,10 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
7471
assert(ridgeErr < linearErr,
7572
"ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")")
7673
}
74+
75+
test("do not support intercept") {
76+
intercept[UnsupportedOperationException] {
77+
new RidgeRegressionWithSGD().setIntercept(true)
78+
}
79+
}
7780
}

0 commit comments

Comments (0)