
Commit 4eb078d

DB Tsai authored and committed
first commit
1 parent 4d9e560 commit 4eb078d

File tree

7 files changed: +491 -57 lines changed


mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala

Lines changed: 37 additions & 0 deletions
@@ -42,6 +42,43 @@ trait HasRegParam extends Params {
   final def getRegParam: Double = getOrDefault(regParam)
 }
 
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param elasticNetParam.
+ */
+@DeveloperApi
+trait HasElasticNetParam extends HasRegParam {
+
+  /**
+   * param for elastic net regularization parameter
+   * @group param
+   */
+  final val elasticNetParam: DoubleParam =
+    new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter")
+
+  /** @group getParam */
+  def getElasticNetParam: Double = getOrDefault(elasticNetParam)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param tol.
+ */
+@DeveloperApi
+trait HasTol extends Params {
+
+  /**
+   * param for the convergence tolerance in the iterative algorithms;
+   * smaller value will lead to higher accuracy with the cost of more iterations
+   * @group param
+   */
+  final val tol: DoubleParam =
+    new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms")
+
+  /** @group getParam */
+  def getTol: Double = getOrDefault(tol)
+}
+
 /**
  * :: DeveloperApi ::
  * Trait for shared param maxIter.
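
The new traits follow the existing shared-param pattern: a component mixes them into its params trait and exposes setters that delegate to set(...). A minimal sketch of that wiring, mirroring what LinearRegressionParams does in the next file (illustration only; MyRegressorParams is a hypothetical name, not part of this commit):

// Sketch: consuming the new shared params the same way LinearRegression does below.
private[regression] trait MyRegressorParams extends RegressorParams
  with HasElasticNetParam with HasMaxIter with HasTol

// A concrete estimator mixing in MyRegressorParams would then expose setters, e.g.
//   def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value)
//   def setTol(value: Double): this.type = set(tol, value)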

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 268 additions & 22 deletions
@@ -17,21 +17,30 @@
 
 package org.apache.spark.ml.regression
 
+import org.apache.spark.mllib.linalg.BLAS.dot
+import org.apache.spark.rdd.RDD
+
+import scala.collection.mutable.ArrayBuffer
+
+import breeze.linalg.{norm => brzNorm, DenseVector => BDV, SparseVector => BSV}
+import breeze.optimize.{LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
+import breeze.optimize.{CachedDiffFunction, DiffFunction}
+
 import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasTol}
 import org.apache.spark.ml.param.{Params, ParamMap}
-import org.apache.spark.ml.param.shared._
-import org.apache.spark.mllib.linalg.{BLAS, Vector}
-import org.apache.spark.mllib.regression.LinearRegressionWithSGD
+import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.mllib.linalg.BLAS._
+import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
+import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.storage.StorageLevel
 
-
 /**
  * Params for linear regression.
  */
 private[regression] trait LinearRegressionParams extends RegressorParams
-  with HasRegParam with HasMaxIter
-
+  with HasElasticNetParam with HasMaxIter with HasTol
 
 /**
  * :: AlphaComponent ::
@@ -42,34 +51,116 @@ private[regression] trait LinearRegressionParams extends RegressorParams
 class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
   with LinearRegressionParams {
 
-  setDefault(regParam -> 0.1, maxIter -> 100)
-
-  /** @group setParam */
+  /**
+   * Set the regularization parameter.
+   * Default is 0.0.
+   * @group setParam
+   */
   def setRegParam(value: Double): this.type = set(regParam, value)
+  setRegParam(0.0)
+
+  /**
+   * Set the ElasticNet mixing parameter.
+   * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
+   * For 0 < alpha < 1, the penalty is a combination of L1 and L2.
+   * Default is 0.0 which is an L2 penalty.
+   * @group setParam
+   */
+  def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value)
+  setElasticNetParam(0.0)
 
-  /** @group setParam */
+  /**
+   * Set the maximal number of iterations.
+   * Default is 100.
+   * @group setParam
+   */
   def setMaxIter(value: Int): this.type = set(maxIter, value)
+  setMaxIter(100)
+
+  /**
+   * Set the convergence tolerance of iterations.
+   * Smaller value will lead to higher accuracy with the cost of more iterations.
+   * Default is 1E-9.
+   * @group setParam
+   */
+  def setTol(value: Double): this.type = set(tol, value)
+  setTol(1E-9)
 
   override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = {
-    // Extract columns from data. If dataset is persisted, do not persist oldDataset.
-    val oldDataset = extractLabeledPoints(dataset, paramMap)
+    // Extract columns from data. If dataset is persisted, do not persist instances.
+    val instances = extractLabeledPoints(dataset, paramMap).map {
+      case LabeledPoint(label: Double, features: Vector) => (label, features)
+    }
     val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
     if (handlePersistence) {
-      oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
+      instances.persist(StorageLevel.MEMORY_AND_DISK)
+    }
+
+    // TODO: Benchmark if using two MultivariateOnlineSummarizer will be faster
+    // than appending the label into the vector.
+    val summary = instances.map { case (label: Double, features: Vector) =>
+      Vectors.fromBreeze(features.toBreeze match {
+        case dv: BDV[Double] => BDV.vertcat(dv, new BDV[Double](Array(label)))
+        case sv: BSV[Double] => BSV.vertcat(sv, new BSV[Double](Array(0), Array(label), 1))
+        case v: Any =>
+          throw new IllegalArgumentException("Do not support vector type " + v.getClass)
+      })}.treeAggregate(new MultivariateOnlineSummarizer)(
+        (aggregator, data) => aggregator.add(data),
+        (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
+
+    val numFeatures = summary.mean.size - 1
+    val yMean = summary.mean(numFeatures)
+    val yStd = math.sqrt(summary.variance(numFeatures))
+
+    val effectiveRegParam = paramMap(regParam) / yStd
+    val effectiveL1RegParam = paramMap(elasticNetParam) * effectiveRegParam
+    val effectiveL2RegParam = (1.0 - paramMap(elasticNetParam)) * effectiveRegParam
+
+    val costFun = new LeastSquaresCostFun(
+      instances,
+      yStd, yMean,
+      summary.variance.toArray.slice(0, numFeatures).map(Math.sqrt(_)).toArray,
+      summary.mean.toArray.slice(0, numFeatures).toArray,
+      effectiveL2RegParam)
+
+    val optimizer = if (paramMap(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
+      new BreezeLBFGS[BDV[Double]](paramMap(maxIter), 10, paramMap(tol))
+    } else {
+      new BreezeOWLQN[Int, BDV[Double]](
+        paramMap(maxIter), 10, effectiveL1RegParam, paramMap(tol))
     }
 
-    // Train model
-    val lr = new LinearRegressionWithSGD()
-    lr.optimizer
-      .setRegParam(paramMap(regParam))
-      .setNumIterations(paramMap(maxIter))
-    val model = lr.run(oldDataset)
-    val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept)
+    val initialWeights = Vectors.zeros(numFeatures)
+    val states =
+      optimizer.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector)
+
+    var state = states.next()
+    val lossHistory = new ArrayBuffer[Double](paramMap(maxIter))
+    while(states.hasNext) {
+      lossHistory.append(state.value)
+      state = states.next()
+    }
+    lossHistory.append(state.value)
+
+    val weights = {
+      val rawWeights = state.x.toArray
+      val std = summary.variance.toArray.slice(0, numFeatures).map(Math.sqrt(_)).toArray
+      require(rawWeights.size == std.size)
+
+      var i = 0
+      while (i < rawWeights.size) {
+        rawWeights(i) = if (std(i) != 0.0) rawWeights(i) * yStd / std(i) else 0.0
+        i += 1
+      }
+      Vectors.dense(rawWeights)
+    }
+
+    val intercept = yMean - dot(weights, Vectors.dense(summary.mean.toArray.slice(0, numFeatures)))
 
     if (handlePersistence) {
-      oldDataset.unpersist()
+      instances.unpersist()
     }
-    lrm
+    new LinearRegressionModel(this, paramMap, weights, intercept)
   }
 }
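
For readers of the diff: written out, the problem that train() now solves (notation mine, not part of the commit) uses the label mean ȳ and standard deviation σ_y, the feature means x̄_j and standard deviations σ_j from MultivariateOnlineSummarizer, α = elasticNetParam, and λ = regParam:

$$
\lambda_{\mathrm{eff}} = \frac{\lambda}{\sigma_y}, \qquad
\lambda_1 = \alpha\,\lambda_{\mathrm{eff}}, \qquad
\lambda_2 = (1 - \alpha)\,\lambda_{\mathrm{eff}}
$$

$$
\min_{w}\ \frac{1}{2n} \sum_{i=1}^{n}
\Bigl( \sum_{j} \frac{x_{ij} - \bar{x}_j}{\sigma_j}\, w_j - \frac{y_i - \bar{y}}{\sigma_y} \Bigr)^{2}
\;+\; \frac{\lambda_2}{2} \lVert w \rVert_2^2 \;+\; \lambda_1 \lVert w \rVert_1
$$

LBFGS minimizes the smooth part; when α > 0 and λ > 0, OWLQN is used instead and supplies the λ₁‖w‖₁ term. The optimizer's weights live in this standardized space, which is why the block at the end of train() rescales them by σ_y/σ_j and recovers the intercept as b = ȳ − Σ_j w_j^orig · x̄_j rather than fitting it directly.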

@@ -97,3 +188,158 @@ class LinearRegressionModel private[ml] (
     m
   }
 }
+
+private class LeastSquaresAggregator(
+    weights: Vector,
+    labelStd: Double,
+    labelMean: Double,
+    featuresStd: Array[Double],
+    featuresMean: Array[Double]) extends Serializable {
+
+  private var totalCnt: Long = 0
+  private var lossSum = 0.0
+  private var diffSum = 0.0
+
+  private val (effectiveWeightsArray: Array[Double], offset: Double, dim: Int) = {
+    val weightsArray = weights.toArray.clone()
+    var sum = 0.0
+    var i = 0
+    while (i < weights.size) {
+      if (featuresStd(i) != 0.0) {
+        weightsArray(i) /= featuresStd(i)
+        sum += weightsArray(i) * featuresMean(i)
+      } else {
+        weightsArray(i) = 0.0
+      }
+      i += 1
+    }
+    (weightsArray, -sum, weightsArray.length)
+  }
+
+  private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray)
+  private var gradientSumArray: Array[Double] = Array.ofDim[Double](dim)
+
+  /**
+   * Add a new training data to this LeastSquaresAggregator, and update the loss and gradient
+   * of the objective function.
+   *
+   * @param label The label for this data point.
+   * @param data The features for one data point in dense/sparse vector format to be added
+   *             into this aggregator.
+   * @return This LeastSquaresAggregator object.
+   */
+  def add(label: Double, data: Vector): this.type = {
+    require(dim == data.size, s"Dimensions mismatch when adding new sample." +
+      s" Expecting $dim but got ${data.size}.")
+
+    val diff = dot(data, effectiveWeightsVector) - (label - labelMean) / labelStd + offset
+
+    if (diff != 0) {
+      val localGradientSumArray = gradientSumArray
+      data.foreachActive { (index, value) =>
+        if (featuresStd(index) != 0.0 && value != 0.0) {
+          localGradientSumArray(index) += diff * value / featuresStd(index)
+        }
+      }
+      lossSum += diff * diff / 2.0
+      diffSum += diff
+    }
+
+    totalCnt += 1
+    this
+  }
+
+  /**
+   * Merge another LeastSquaresAggregator, and update the loss and gradient
+   * of the objective function.
+   * (Note that it's in place merging; as a result, `this` object will be modified.)
+   *
+   * @param other The other LeastSquaresAggregator to be merged.
+   * @return This LeastSquaresAggregator object.
+   */
+  def merge(other: LeastSquaresAggregator): this.type = {
+    if (this.totalCnt != 0 && other.totalCnt != 0) {
+      require(dim == other.dim, s"Dimensions mismatch when merging with another summarizer. " +
+        s"Expecting $dim but got ${other.dim}.")
+      totalCnt += other.totalCnt
+      lossSum += other.lossSum
+      diffSum += other.diffSum
+
+      var i = 0
+      val localThisGradientSumArray = gradientSumArray
+      val localOtherGradientSumArray = other.gradientSumArray
+      while (i < dim) {
+        localThisGradientSumArray(i) += localOtherGradientSumArray(i)
+        i += 1
+      }
+    } else if (totalCnt == 0 && other.totalCnt != 0) {
+      this.totalCnt = other.totalCnt
+      this.lossSum = other.lossSum
+      this.diffSum = other.diffSum
+      this.gradientSumArray = other.gradientSumArray.clone
+    }
+    this
+  }
+
+  def count: Long = totalCnt
+
+  def loss: Double = lossSum / totalCnt
+
+  def gradient: Vector = {
+    val result = Vectors.dense(gradientSumArray.clone)
+
+    val correction = {
+      val temp = effectiveWeightsArray.clone
+      var i = 0
+      while (i < temp.size) {
+        temp(i) *= featuresMean(i)
+        i += 1
+      }
+      Vectors.dense(temp)
+    }
+
+    // Accumulate the correction into the result: result += -diffSum * correction.
+    axpy(-diffSum, correction, result)
+    scal(1.0 / totalCnt, result)
+    result
+  }
+}
+
+private class LeastSquaresCostFun(
+    data: RDD[(Double, Vector)],
+    labelStd: Double,
+    labelMean: Double,
+    featuresStd: Array[Double],
+    featuresMean: Array[Double],
+    effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
+
+  override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
+    val w = Vectors.fromBreeze(weights)
+
+    val leastSquaresAggregator = data.treeAggregate(
+      new LeastSquaresAggregator(w, labelStd, labelMean, featuresStd, featuresMean))(
+        seqOp = (c, v) => (c, v) match {
+          case (aggregator, (label, features)) => aggregator.add(label, features)
+        },
+        combOp = (c1, c2) => (c1, c2) match {
+          case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
+        })
+
+    /**
+     * regVal is sum of weight squares if it's L2 updater;
+     * for other updater, the same logic is followed.
+     */
+    val norm = brzNorm(weights, 2.0)
+    val regVal = 0.5 * effectiveL2regParam * norm * norm
+
+    val loss = leastSquaresAggregator.loss + regVal
+    // The following gradientTotal is actually the regularization part of gradient.
+    // Will add the gradientSum computed from the data with weights in the next step.
+    val gradientTotal = w.copy
+    scal(effectiveL2regParam, gradientTotal)
+
+    // gradientTotal = gradient + gradientTotal
+    axpy(1.0, leastSquaresAggregator.gradient, gradientTotal)
+
+    (loss, gradientTotal.toBreeze.asInstanceOf[BDV[Double]])
+  }
+}
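
Taken together, the commit swaps the old LinearRegressionWithSGD training path for a standardized LBFGS/OWLQN solve over LeastSquaresCostFun. A minimal usage sketch, assuming a DataFrame named training with "label" and "features" columns (illustrative only, not part of this commit):

val lr = new LinearRegression()
  .setRegParam(0.3)          // lambda
  .setElasticNetParam(0.8)   // alpha: 0.8 L1 + 0.2 L2 mix, so OWLQN is selected
  .setMaxIter(100)
  .setTol(1e-9)
val model = lr.fit(training)
// model.weights and model.intercept are reported on the original feature/label
// scale, thanks to the un-standardization step at the end of train().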
