
Commit aa7e768

ML LinearRegression supports bound constrained optimization.
1 parent 13538cf commit aa7e768
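
For context, a minimal usage sketch of the API this commit adds (a hypothetical example: `training` is assumed to be an existing DataFrame with two-feature `features` vectors and a `label` column; the bound values are illustrative):

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.LinearRegression

// Constrain each of the two coefficients to [0.0, 10.0]. Bounds are given on
// the original feature scale, one entry per feature.
val lr = new LinearRegression()
  .setLowerBoundOfCoefficients(Vectors.dense(0.0, 0.0))
  .setUpperBoundOfCoefficients(Vectors.dense(10.0, 10.0))
  .setRegParam(0.1)        // bound constrained fitting only supports L2,
  .setElasticNetParam(0.0) // so elasticNetParam must stay at its default 0.0
val model = lr.fit(training)

Under the default "auto" solver, setting either bound steers training away from the WeightedLeastSquares ("normal") path and onto Breeze's L-BFGS-B solver.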

2 files changed (+357, -8 lines)

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 115 additions & 8 deletions
@@ -20,7 +20,7 @@ package org.apache.spark.ml.regression
 import scala.collection.mutable
 
 import breeze.linalg.{DenseVector => BDV}
-import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
+import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN}
 import breeze.stats.distributions.StudentsT
 import org.apache.hadoop.fs.Path
 
@@ -33,7 +33,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.linalg.BLAS._
 import org.apache.spark.ml.optim.WeightedLeastSquares
 import org.apache.spark.ml.PredictorParams
-import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.evaluation.RegressionMetrics
@@ -43,7 +43,7 @@ import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DoubleType
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 import org.apache.spark.storage.StorageLevel
 
 /**
@@ -52,7 +52,36 @@ import org.apache.spark.storage.StorageLevel
 private[regression] trait LinearRegressionParams extends PredictorParams
     with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
     with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver
-    with HasAggregationDepth
+    with HasAggregationDepth {
+
+  /**
+   * The lower bound of coefficients if fitting under bound constrained optimization.
+   * The bound vector size must be equal to the number of features in the training dataset;
+   * otherwise, an exception is thrown.
+   * @group param
+   */
+  @Since("2.2.0")
+  val lowerBoundOfCoefficients: Param[Vector] = new Param(this, "lowerBoundOfCoefficients",
+    "The lower bound of coefficients if fitting under bound constrained optimization.")
+
+  /** @group getParam */
+  @Since("2.2.0")
+  def getLowerBoundOfCoefficients: Vector = $(lowerBoundOfCoefficients)
+
+  /**
+   * The upper bound of coefficients if fitting under bound constrained optimization.
+   * The bound vector size must be equal to the number of features in the training dataset;
+   * otherwise, an exception is thrown.
+   * @group param
+   */
+  @Since("2.2.0")
+  val upperBoundOfCoefficients: Param[Vector] = new Param(this, "upperBoundOfCoefficients",
+    "The upper bound of coefficients if fitting under bound constrained optimization.")
+
+  /** @group getParam */
+  @Since("2.2.0")
+  def getUpperBoundOfCoefficients: Vector = $(upperBoundOfCoefficients)
+}
 
 /**
  * Linear regression.
@@ -123,6 +152,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
    * For alpha in (0,1), the penalty is a combination of L1 and L2.
    * Default is 0.0 which is an L2 penalty.
    *
+   * Note: Fitting under bound constrained optimization only supports L2 regularization,
+   * so an exception is thrown if this param is set to a non-zero value.
+   *
    * @group setParam
    */
  @Since("1.4.0")
@@ -193,11 +225,63 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
   def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
   setDefault(aggregationDepth -> 2)
 
+  /**
+   * Set the lower bound of coefficients if fitting under bound constrained optimization.
+   *
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setLowerBoundOfCoefficients(value: Vector): this.type = set(lowerBoundOfCoefficients, value)
+
+  /**
+   * Set the upper bound of coefficients if fitting under bound constrained optimization.
+   *
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setUpperBoundOfCoefficients(value: Vector): this.type = set(upperBoundOfCoefficients, value)
+
+  private def usingBoundConstrainedOptimization: Boolean = {
+    isSet(lowerBoundOfCoefficients) || isSet(upperBoundOfCoefficients)
+  }
+
+  @Since("2.2.0")
+  override def validateAndTransformSchema(
+      schema: StructType,
+      fitting: Boolean,
+      featuresDataType: DataType): StructType = {
+    if (usingBoundConstrainedOptimization) {
+      require($(elasticNetParam) == 0.0, "Fitting linear regression under bound constrained " +
+        s"optimization only supports L2 regularization, but got elasticNetParam = $getElasticNetParam.")
+    }
+    super.validateAndTransformSchema(schema, fitting, featuresDataType)
+  }
+
   override protected def train(dataset: Dataset[_]): LinearRegressionModel = {
     // Extract the number of features before deciding optimization solver.
     val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
 
+    // Check that the params interaction is valid if fitting under bound constrained optimization.
+    if (usingBoundConstrainedOptimization) {
+      require($(lowerBoundOfCoefficients).size == numFeatures &&
+        $(upperBoundOfCoefficients).size == numFeatures,
+        "The size of the coefficients bound does not match the number of features: " +
+        s"lowerBoundOfCoefficients size = ${getLowerBoundOfCoefficients.size}, " +
+        s"upperBoundOfCoefficients size = ${getUpperBoundOfCoefficients.size}, " +
+        s"number of features = $numFeatures.")
+
+      val validBound = $(lowerBoundOfCoefficients).toArray.zip($(upperBoundOfCoefficients).toArray)
+        .forall(x => x._1 <= x._2)
+      require(validBound,
+        "lowerBoundOfCoefficients must be less than or equal to upperBoundOfCoefficients " +
+        "elementwise, but found: " +
+        s"lowerBoundOfCoefficients = $getLowerBoundOfCoefficients, " +
+        s"upperBoundOfCoefficients = $getUpperBoundOfCoefficients.")
+    }
+
     val instances: RDD[Instance] = dataset.select(
       col($(labelCol)), w, col($(featuresCol))).rdd.map {
       case Row(label: Double, weight: Double, features: Vector) =>
@@ -209,8 +293,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
       elasticNetParam, fitIntercept, maxIter, regParam, standardization, aggregationDepth)
     instr.logNumFeatures(numFeatures)
 
-    if (($(solver) == "auto" &&
-      numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") {
+    if (($(solver) == "auto" && numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) &&
+      !usingBoundConstrainedOptimization || $(solver) == "normal") {
       // For low dimensional data, WeightedLeastSquares is more efficient since the
       // training algorithm only requires one pass through the data. (SPARK-10668)
 
@@ -322,8 +406,30 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept),
       $(standardization), bcFeaturesStd, bcFeaturesMean, effectiveL2RegParam, $(aggregationDepth))
 
+    var initialValues: Array[Double] = null
+
     val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
-      new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
+      if (usingBoundConstrainedOptimization) {
+        val lowerBound = BDV[Double]($(lowerBoundOfCoefficients).toArray.zip(featuresStd)
+          .map { case (lb, xStd) => lb * xStd / yStd })
+        val upperBound = BDV[Double]($(upperBoundOfCoefficients).toArray.zip(featuresStd)
+          .map { case (ub, xStd) => ub * xStd / yStd })
+        initialValues = lowerBound.toArray.zip(upperBound.toArray).map { case (lb, ub) =>
+          if (lb.isInfinity && ub.isInfinity) {
+            0.0
+          } else if (lb.isInfinity) {
+            ub
+          } else if (ub.isInfinity) {
+            lb
+          } else {
+            lb + (ub - lb) / 2.0
+          }
+        }
+        new BreezeLBFGSB(lowerBound, upperBound, $(maxIter), 10, $(tol))
+      } else {
+        initialValues = Array.fill(numFeatures)(0.0)
+        new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
+      }
     } else {
       val standardizationParam = $(standardization)
       def effectiveL1RegFun = (index: Int) => {
@@ -338,10 +444,11 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
           if (featuresStd(index) != 0.0) effectiveL1RegParam / featuresStd(index) else 0.0
         }
       }
+      initialValues = Array.fill(numFeatures)(0.0)
       new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegFun, $(tol))
     }
 
-    val initialCoefficients = Vectors.zeros(numFeatures)
+    val initialCoefficients = Vectors.dense(initialValues)
     val states = optimizer.iterations(new CachedDiffFunction(costFun),
       initialCoefficients.asBreeze.toDenseVector)
 
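A note on the rescaling inside the L-BFGS-B branch above: the objective is minimized over coefficients of the standardized features and label, and a coefficient beta on the original scale corresponds to beta * xStd / yStd after standardization, so the user-supplied bounds are pushed through the same transform before being handed to LBFGSB (and the initial point is then picked inside the resulting box). A standalone sketch of that mapping, with illustrative names that are not part of the patch:

// Map a bound vector from the original feature scale into the standardized
// space the optimizer works in: betaStd = beta * xStd / yStd.
def scaleBounds(bounds: Array[Double], featuresStd: Array[Double], yStd: Double): Array[Double] =
  bounds.zip(featuresStd).map { case (b, xStd) => b * xStd / yStd }

// Example: a bound of 10.0 on a feature with xStd = 2.0, under yStd = 4.0,
// maps to 10.0 * 2.0 / 4.0 = 5.0 in the standardized space.
assert(scaleBounds(Array(10.0), Array(2.0), 4.0).sameElements(Array(5.0)))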