
Commit baa0805

Author: DB Tsai (committed)

touch up
1 parent d6234ba commit baa0805

File tree

2 files changed: +51 -37 lines


mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 50 additions & 35 deletions
@@ -85,7 +85,11 @@ class LinearRegression(override val uid: String)
   setDefault(fitIntercept -> true)
 
   /**
-   * Set to enable scaling (standardization).
+   * Whether to standardize the training features before fitting the model.
+   * The coefficients of the models are always returned on the original scale,
+   * so standardization is transparent to users. Note that without any
+   * regularization, the models should converge to the same solution whether
+   * standardization is enabled or not.
    * Default is true.
    * @group setParam
    */
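As a usage note on the parameter documented above: a minimal, hypothetical sketch of toggling standardization through the ML API. The setStandardization setter follows the standard param pattern implied by $(standardization) in this diff, and `training` is an assumed DataFrame of (label, features) rows, so treat this as an illustration rather than the commit's own API surface.

import org.apache.spark.ml.regression.LinearRegression

// Hypothetical usage sketch: `training` is an assumed (label, features) DataFrame.
val lr = new LinearRegression()
  .setRegParam(0.1)
  .setElasticNetParam(1.0)    // pure L1, so the OWLQN branch below is taken
  .setStandardization(false)  // penalize components on the original scale

val model = lr.fit(training)

// Per the doc comment, weights come back on the original scale either way.
println(model.weights)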
@@ -178,7 +182,19 @@ class LinearRegression(override val uid: String)
     val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
       new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
     } else {
-      new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegParam, $(tol))
+      def regParamL1Fun = (index: Int) => {
+        if ($(standardization)) {
+          effectiveL1RegParam
+        } else {
+          // If `standardization` is false, we still standardize the data
+          // to improve the rate of convergence; as a result, we have to
+          // perform this reverse standardization by penalizing each component
+          // differently to get effectively the same objective function when
+          // the training dataset is not standardized.
+          if (featuresStd(index) != 0.0) effectiveL1RegParam / featuresStd(index) else 0.0
+        }
+      }
+      new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
     }
 
     val initialWeights = Vectors.zeros(numFeatures)
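A sketch of the algebra behind regParamL1Fun, writing \sigma_j for featuresStd(index): the optimizer works on weights \bar{w}_j attached to the standardized features x_j / \sigma_j, which correspond to original-scale coefficients w_j = \bar{w}_j / \sigma_j. An L1 penalty on the original-scale coefficients therefore becomes a per-component penalty in the standardized space:

\[
\lambda_1 \sum_j |w_j| \;=\; \lambda_1 \sum_j \frac{|\bar{w}_j|}{\sigma_j} \;=\; \sum_j \frac{\lambda_1}{\sigma_j}\,|\bar{w}_j|,
\]

which is exactly the per-index weight handed to BreezeOWLQN when standardization is false; constant features (\sigma_j = 0) are left unpenalized.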
@@ -464,6 +480,7 @@ class LinearRegressionSummary private[regression] (
  * @param weights The weights/coefficients corresponding to the features.
  * @param labelStd The standard deviation value of the label.
  * @param labelMean The mean value of the label.
+ * @param fitIntercept Whether to fit an intercept term.
  * @param featuresStd The standard deviation values of the features.
  * @param featuresMean The mean values of the features.
  */
@@ -472,7 +489,6 @@ private class LeastSquaresAggregator(
     labelStd: Double,
     labelMean: Double,
     fitIntercept: Boolean,
-    standardization: Boolean,
     featuresStd: Array[Double],
     featuresMean: Array[Double]) extends Serializable {
 
@@ -519,11 +535,7 @@ private class LeastSquaresAggregator(
       val localGradientSumArray = gradientSumArray
       data.foreachActive { (index, value) =>
         if (featuresStd(index) != 0.0 && value != 0.0) {
-          if (standardization) {
-            localGradientSumArray(index) += diff * value / featuresStd(index)
-          } else {
-            localGradientSumArray(index) += diff * value
-          }
+          localGradientSumArray(index) += diff * value / featuresStd(index)
         }
       }
       lossSum += diff * diff / 2.0
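A note on this simplification: the aggregator now always accumulates gradients in the standardized space, where each active feature value contributes

\[
\frac{\partial}{\partial \bar{w}_j}\,\frac{1}{2}\,\mathrm{diff}^2 \;=\; \mathrm{diff} \cdot \frac{x_j}{\sigma_j},
\]

so the standardization flag no longer needs to be threaded through LeastSquaresAggregator; reverse standardization is handled entirely by the regularization terms.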
@@ -590,43 +602,46 @@ private class LeastSquaresCostFun(
     val w = Vectors.fromBreeze(weights)
 
     val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd,
-      labelMean, fitIntercept, standardization, featuresStd, featuresMean))(
+      labelMean, fitIntercept, featuresStd, featuresMean))(
       seqOp = (c, v) => (c, v) match {
         case (aggregator, (label, features)) => aggregator.add(label, features)
       },
       combOp = (c1, c2) => (c1, c2) match {
         case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
       })
 
-    // If we are not doing standardization go back to unscaled weights
-    if (standardization) {
-      // regVal is the sum of weight squares for L2 regularization
-      val norm = brzNorm(weights, 2.0)
-      val regVal = 0.5 * effectiveL2regParam * norm * norm
-
-      val loss = leastSquaresAggregator.loss + regVal
-      val gradient = leastSquaresAggregator.gradient
-      axpy(effectiveL2regParam, w, gradient)
+    val totalGradientArray = leastSquaresAggregator.gradient.toArray
 
-      (loss, gradient.toBreeze.asInstanceOf[BDV[Double]])
+    val regVal = if (effectiveL2regParam == 0.0) {
+      0.0
     } else {
-      val unscaledWeights = weights.copy
-      val len = unscaledWeights.length
-      var i = 0
-      while (i < len) {
-        unscaledWeights(i) /= featuresStd(i)
-        i += 1
+      var sum = 0.0
+      w.foreachActive { (index, value) =>
+        // The following code computes both the loss and the gradient of the
+        // regularization term, adding the gradient back to totalGradientArray.
+        sum += {
+          if (standardization) {
+            totalGradientArray(index) += effectiveL2regParam * value
+            value * value
+          } else {
+            if (featuresStd(index) != 0.0) {
+              // If `standardization` is false, we still standardize the data
+              // to improve the rate of convergence; as a result, we have to
+              // perform this reverse standardization by penalizing each component
+              // differently to get effectively the same objective function when
+              // the training dataset is not standardized.
+              val temp = value / (featuresStd(index) * featuresStd(index))
+              totalGradientArray(index) += effectiveL2regParam * temp
+              value * temp
+            } else {
+              0.0
+            }
+          }
+        }
       }
-      val norm = brzNorm(unscaledWeights, 2.0)
-
-      val regVal = 0.5 * effectiveL2regParam * norm * norm
-
-      val loss = leastSquaresAggregator.loss + regVal
-      val gradient = leastSquaresAggregator.gradient
-      val mw = Vectors.dense(unscaledWeights.toArray)
-      axpy(effectiveL2regParam, mw, gradient)
-
-      (loss, gradient.toBreeze.asInstanceOf[BDV[Double]])
+      0.5 * effectiveL2regParam * sum
     }
+
+    (leastSquaresAggregator.loss + regVal, new BDV(totalGradientArray))
   }
 }
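The L2 branch above applies the same reverse-standardization idea; a sketch of the algebra, again writing \sigma_j for featuresStd(index) and \lambda_2 for effectiveL2regParam. With standardization enabled, the penalty and its gradient are the usual ridge terms,

\[
R(\bar{w}) = \frac{\lambda_2}{2} \sum_j \bar{w}_j^2, \qquad \frac{\partial R}{\partial \bar{w}_j} = \lambda_2\,\bar{w}_j,
\]

matching sum += value * value and totalGradientArray(index) += effectiveL2regParam * value. With it disabled, the penalty instead targets the original-scale coefficients \bar{w}_j / \sigma_j,

\[
R(\bar{w}) = \frac{\lambda_2}{2} \sum_j \left(\frac{\bar{w}_j}{\sigma_j}\right)^2, \qquad \frac{\partial R}{\partial \bar{w}_j} = \frac{\lambda_2\,\bar{w}_j}{\sigma_j^2},
\]

which is the val temp = value / (featuresStd(index) * featuresStd(index)) path, with value * temp accumulating \bar{w}_j^2 / \sigma_j^2 into sum.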

mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala

Lines changed: 1 addition & 2 deletions
@@ -151,7 +151,7 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
        Then again with the data with no intercept:
        > weightsWithoutIntercept
        3 x 1 sparse Matrix of class "dgCMatrix"
-                            s0
+                             s0
        (Intercept)         .
        as.numeric.data3.V2. 4.70011
        as.numeric.data3.V3. 7.19943
@@ -505,5 +505,4 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
       .zip(testSummary.residuals.select("residuals").collect())
       .forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 }
   }
-
 }
