Commit e47c574

Add support for L2 without standardization.
1 parent 55d3a66 commit e47c574

2 files changed: +93 -12 lines changed
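
The commit makes L2 (ridge) and elastic-net fits honor `setStandardization(false)`, matching glmnet's `standardize=FALSE`. A minimal usage sketch, assembled from the new tests below rather than taken from the commit itself; it assumes a DataFrame `dataset` with "label" and "features" columns, as in `LinearRegressionSuite`:

```scala
import org.apache.spark.ml.regression.LinearRegression

// Ridge regression (elasticNetParam = 0.0) with the penalty applied to the
// weights on the original, unstandardized feature scale.
val trainer = new LinearRegression()
  .setElasticNetParam(0.0)     // pure L2 penalty
  .setRegParam(2.3)
  .setStandardization(false)   // the L2 code path this commit adds
val model = trainer.fit(dataset)
println(s"intercept = ${model.intercept}, weights = ${model.weights}")
```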

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 28 additions & 7 deletions

```diff
@@ -453,14 +453,35 @@ private class LeastSquaresCostFun(
           case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
         })
 
-    // regVal is the sum of weight squares for L2 regularization
-    val norm = brzNorm(weights, 2.0)
-    val regVal = 0.5 * effectiveL2regParam * norm * norm
+    // If we are not doing standardization go back to unscaled weights
+    if (standardization) {
+      // regVal is the sum of weight squares for L2 regularization
+      val norm = brzNorm(weights, 2.0)
+      val regVal = 0.5 * effectiveL2regParam * norm * norm
 
-    val loss = leastSquaresAggregator.loss + regVal
-    val gradient = leastSquaresAggregator.gradient
-    axpy(effectiveL2regParam, w, gradient)
+      val loss = leastSquaresAggregator.loss + regVal
+      val gradient = leastSquaresAggregator.gradient
+      axpy(effectiveL2regParam, w, gradient)
 
-    (loss, gradient.toBreeze.asInstanceOf[BDV[Double]])
+      (loss, gradient.toBreeze.asInstanceOf[BDV[Double]])
+    } else {
+      val unscaledWeights = weights.copy
+      val len = unscaledWeights.length
+      var i = 0
+      while (i < len) {
+        unscaledWeights(i) /= featuresStd(i)
+        i += 1
+      }
+      val norm = brzNorm(unscaledWeights, 2.0)
+
+      val regVal = 0.5 * effectiveL2regParam * norm * norm
+
+      val loss = leastSquaresAggregator.loss + regVal
+      val gradient = leastSquaresAggregator.gradient
+      val mw = Vectors.dense(unscaledWeights.toArray)
+      axpy(effectiveL2regParam, mw, gradient)
+
+      (loss, gradient.toBreeze.asInstanceOf[BDV[Double]])
+    }
   }
 }
```
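
For readers skimming the diff: in the new `else` branch the L2 penalty is computed on the weights mapped back to the original feature scale, i.e. `weights(i) / featuresStd(i)`, and that same rescaled vector, scaled by `effectiveL2regParam`, is what `axpy` adds to the gradient. A small standalone sketch of just that term (illustrative names, not code from the commit):

```scala
// Illustrates the regularization term in the else branch above, outside Spark.
// `weights` stands in for the scaled weights, `featuresStd` for the per-feature
// standard deviations, and `regParam` for effectiveL2regParam.
object L2WithoutStandardizationSketch {
  def l2Term(
      weights: Array[Double],
      featuresStd: Array[Double],
      regParam: Double): (Double, Array[Double]) = {
    // Rescale each weight back to the original (unstandardized) feature scale.
    val unscaled = weights.zip(featuresStd).map { case (w, std) => w / std }
    // regVal = 0.5 * regParam * ||unscaled||^2, as in the diff.
    val regVal = 0.5 * regParam * unscaled.map(x => x * x).sum
    // Gradient contribution added via axpy in the diff: regParam * unscaled(i).
    val gradContribution = unscaled.map(_ * regParam)
    (regVal, gradContribution)
  }

  def main(args: Array[String]): Unit = {
    val (regVal, grad) = l2Term(Array(3.0, 5.0), Array(1.5, 2.0), 2.3)
    println(s"regVal = $regVal, gradient contribution = [${grad.mkString(", ")}]")
  }
}
```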

mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala

Lines changed: 65 additions & 5 deletions

```diff
@@ -125,7 +125,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(modelWithoutIntercept.weights(1) ~== weightsWithoutInterceptR(1) relTol 1E-3)
   }
 
-  test("linear regression with intercept with L1 regularization") {
+  test("linear regression with intercept with L1 regularization with standardization") {
     val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
     val model = trainer.fit(dataset)
 
@@ -153,7 +153,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
-  test("linear regression with intercept with L1 regularization with standardization turned off") {
+  test("linear regression with intercept with L1 regularization without standardization") {
     val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
       .setStandardization(false)
     val model = trainer.fit(dataset)
@@ -215,9 +215,9 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
      * > weights
      * 3 x 1 sparse Matrix of class "dgCMatrix"
      *                           s0
-     * (Intercept)         6.328062
-     * as.numeric.data.V2. 3.222034
-     * as.numeric.data.V3. 4.926260
+     * (Intercept)         5.269376
+     * as.numeric.data.V2. 3.736216
+     * as.numeric.data.V3. 5.712356)
      */
     val interceptR = 5.269376
     val weightsR = Array(3.736216, 5.712356)
@@ -234,6 +234,36 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("linear regression with intercept with L2 regularization without standardization") {
+    val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
+      .setStandardization(false)
+    val model = trainer.fit(dataset)
+
+    /**
+     * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
+     *   standardize=FALSE))
+     * > weights
+     * 3 x 1 sparse Matrix of class "dgCMatrix"
+     *                           s0
+     * (Intercept)         5.791109
+     * as.numeric.data.V2. 3.435466
+     * as.numeric.data.V3. 5.910406
+     */
+    val interceptR = 5.791109
+    val weightsR = Array(3.435466, 5.910406)
+
+    assert(model.intercept ~== interceptR relTol 1E-3)
+    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
+    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
+
+    model.transform(dataset).select("features", "prediction").collect().foreach {
+      case Row(features: DenseVector, prediction1: Double) =>
+        val prediction2 =
+          features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+        assert(prediction1 ~== prediction2 relTol 1E-5)
+    }
+  }
+
   test("linear regression without intercept with L2 regularization") {
     val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
       .setFitIntercept(false)
@@ -292,6 +322,36 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("linear regression with intercept with ElasticNet regularization without standardization") {
+    val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
+      .setStandardization(false)
+    val model = trainer.fit(dataset)
+
+    /**
+     * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6
+     *   standardize=FALSE))
+     * > weights
+     * 3 x 1 sparse Matrix of class "dgCMatrix"
+     *                           s0
+     * (Intercept)         6.114723
+     * as.numeric.data.V2. 3.409937
+     * as.numeric.data.V3. 6.146531
+     */
+    val interceptR = 6.114723
+    val weightsR = Array(3.409937, 6.146531)
+
+    assert(model.intercept ~== interceptR relTol 1E-3)
+    assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
+    assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
+
+    model.transform(dataset).select("features", "prediction").collect().foreach {
+      case Row(features: DenseVector, prediction1: Double) =>
+        val prediction2 =
+          features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+        assert(prediction1 ~== prediction2 relTol 1E-5)
+    }
+  }
+
   test("linear regression without intercept with ElasticNet regularization") {
     val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
       .setFitIntercept(false)
```

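In the new elastic-net test, `setElasticNetParam(0.3)` and `setRegParam(1.6)` correspond to glmnet's `alpha = 0.3` and `lambda = 1.6`, i.e. a penalty mixing L1 and L2 terms; with standardization off, the L2 part is taken on the same rescaled weights as above. A hedged sketch of that penalty value, using the standard elastic-net convention rather than code from the commit:

```scala
// Elastic-net penalty under the usual glmnet-style convention (assumption):
//   regParam * (alpha * ||w||_1 + (1 - alpha) / 2 * ||w||_2^2)
// Here `w` would be the unscaled weights when standardization is disabled.
def elasticNetPenalty(w: Array[Double], regParam: Double, alpha: Double): Double = {
  val l1 = w.map(math.abs).sum
  val l2 = w.map(x => x * x).sum
  regParam * (alpha * l1 + (1.0 - alpha) / 2.0 * l2)
}

// Example with the parameters from the new test: regParam = 1.6, alpha = 0.3.
println(elasticNetPenalty(Array(3.4, 6.1), regParam = 1.6, alpha = 0.3))
```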