Skip to content

Commit 0820c04

Browse files
committed
Use SquaredL2Updater in LogisticRegressionWithSGD
SimpleUpdater ignores the regularizer, which leads to an unregularized LogReg. To enable the common L2 regularizer (and the corresponding regularization parameter) for logistic regression the SquaredL2Updater has to be used in SGD (see, e.g., [SVMWithSGD])
1 parent f493f79 commit 0820c04

File tree

2 files changed

+42
-4
lines changed

2 files changed

+42
-4
lines changed

mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class LogisticRegressionWithSGD private (
8484
extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable {
8585

8686
private val gradient = new LogisticGradient()
87-
private val updater = new SimpleUpdater()
87+
private val updater = new SquaredL2Updater()
8888
override val optimizer = new GradientDescent(gradient, updater)
8989
.setStepSize(stepSize)
9090
.setNumIterations(numIterations)

mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ object LogisticRegressionSuite {
4343
offset: Double,
4444
scale: Double,
4545
nPoints: Int,
46-
seed: Int): Seq[LabeledPoint] = {
46+
seed: Int): Seq[LabeledPoint] = {
4747
val rnd = new Random(seed)
4848
val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
4949

@@ -58,12 +58,15 @@ object LogisticRegressionSuite {
5858
}
5959

6060
class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Matchers {
61-
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
61+
def validatePrediction(
62+
predictions: Seq[Double],
63+
input: Seq[LabeledPoint],
64+
expectedAcc: Double = 0.83) {
6265
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
6366
prediction != expected.label
6467
}
6568
// At least 83% of the predictions should be on.
66-
((input.length - numOffPredictions).toDouble / input.length) should be > 0.83
69+
((input.length - numOffPredictions).toDouble / input.length) should be > expectedAcc
6770
}
6871

6972
// Test if we can correctly learn A, B where Y = logistic(A + B*X)
@@ -155,6 +158,41 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
155158
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
156159
}
157160

161+
test("logistic regression with initial weights and non-default regularization parameter") {
162+
val nPoints = 10000
163+
val A = 2.0
164+
val B = -1.5
165+
166+
val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42)
167+
168+
val initialB = -1.0
169+
val initialWeights = Vectors.dense(initialB)
170+
171+
val testRDD = sc.parallelize(testData, 2)
172+
testRDD.cache()
173+
174+
// Use half as many iterations as the previous test.
175+
val lr = new LogisticRegressionWithSGD().setIntercept(true)
176+
lr.optimizer.
177+
setStepSize(10.0).
178+
setNumIterations(10).
179+
setRegParam(1.0)
180+
181+
val model = lr.run(testRDD, initialWeights)
182+
183+
// Test the weights
184+
assert(model.weights(0) ~== -430000.0 relTol 20000.0)
185+
assert(model.intercept ~== 370000.0 relTol 20000.0)
186+
187+
val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
188+
val validationRDD = sc.parallelize(validationData, 2)
189+
// Test prediction on RDD.
190+
validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData, 0.8)
191+
192+
// Test prediction on Array.
193+
validatePrediction(validationData.map(row => model.predict(row.features)), validationData, 0.8)
194+
}
195+
158196
test("logistic regression with initial weights with LBFGS") {
159197
val nPoints = 10000
160198
val A = 2.0

0 commit comments

Comments
 (0)