Skip to content

Commit 5af16cb

Browse files
committed
Update code and test.
1 parent aa7e768 commit 5af16cb

File tree

2 files changed

+80
-26
lines changed

2 files changed

+80
-26
lines changed

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams
5757
/**
5858
* The lower bound of coefficients if fitting under bound constrained optimization.
5959
* The bound vector size must be equal with the number of features in training dataset,
60-
* otherwise, it throws exception.
60+
* otherwise, throws exception.
6161
* @group param
6262
*/
6363
@Since("2.2.0")
@@ -71,7 +71,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams
7171
/**
7272
* The upper bound of coefficients if fitting under bound constrained optimization.
7373
* The bound vector size must be equal with the number of features in training dataset,
74-
* otherwise, it throws exception.
74+
* otherwise, throws exception.
7575
* @group param
7676
*/
7777
@Since("2.2.0")
@@ -153,7 +153,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
153153
* Default is 0.0 which is an L2 penalty.
154154
*
155155
* Note: Fitting under bound constrained optimization only supports L2 regularization,
156-
* so it throws exception if getting non-zero value from this param.
156+
* so throws exception if this param is non-zero value.
157157
*
158158
* @group setParam
159159
*/
@@ -203,6 +203,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
203203
* The Normal Equations solver will be used when possible, but this will automatically fall
204204
* back to iterative optimization methods when needed.
205205
*
206+
* Note: Fitting under bound constrained optimization does not support "normal" solver.
207+
*
206208
* @group setParam
207209
*/
208210
@Since("1.6.0")
@@ -245,14 +247,38 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
245247
isSet(lowerBoundOfCoefficients) || isSet(upperBoundOfCoefficients)
246248
}
247249

250+
private def assertBoundConstrainedOptimizationParamsValid(numFeatures: Int): Unit = {
251+
if (isSet(lowerBoundOfCoefficients)) {
252+
require($(lowerBoundOfCoefficients).size == numFeatures,
253+
"The size of lowerBoundOfCoefficients mismatched with number of features: " +
254+
s"lowerBoundOfCoefficients size = ${getLowerBoundOfCoefficients.size}, " +
255+
s"number of features = $numFeatures.")
256+
}
257+
if (isSet(upperBoundOfCoefficients)) {
258+
require($(upperBoundOfCoefficients).size == numFeatures,
259+
"The size of upperBoundOfCoefficients mismatched with number of features: " +
260+
s"upperBoundOfCoefficients size = ${getUpperBoundOfCoefficients.size}, " +
261+
s"number of features = $numFeatures.")
262+
}
263+
if (isSet(lowerBoundOfCoefficients) && isSet(upperBoundOfCoefficients)) {
264+
require($(lowerBoundOfCoefficients).toArray.zip($(upperBoundOfCoefficients).toArray)
265+
.forall(x => x._1 <= x._2), "LowerBoundOfCoefficients should always " +
266+
"less than or equal to upperBoundOfCoefficients, but found: " +
267+
s"lowerBoundOfCoefficients = $getLowerBoundOfCoefficients, " +
268+
s"upperBoundOfCoefficients = $getUpperBoundOfCoefficients.")
269+
}
270+
}
271+
248272
@Since("2.2.0")
249273
override def validateAndTransformSchema(
250274
schema: StructType,
251275
fitting: Boolean,
252276
featuresDataType: DataType): StructType = {
253-
if (usingBoundConstrainedOptimization && $(elasticNetParam) != 0.0) {
254-
logError("Fitting linear regression under bound constrained optimization only supports " +
255-
s"L2 regularization, but got elasticNetParam = $getElasticNetParam.")
277+
if (usingBoundConstrainedOptimization) {
278+
require($(solver) != "normal", "Fitting under bound constrained optimization " +
279+
"does not support normal solver.")
280+
require($(elasticNetParam) == 0.0, "Fitting under bound constrained optimization only " +
281+
s"supports L2 regularization, but got elasticNetParam = $getElasticNetParam.")
256282
}
257283
super.validateAndTransformSchema(schema, fitting, featuresDataType)
258284
}
@@ -264,22 +290,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
264290

265291
// Check params interaction is valid if fitting under bound constrained optimization.
266292
if (usingBoundConstrainedOptimization) {
267-
if ($(lowerBoundOfCoefficients).size != numFeatures ||
268-
$(upperBoundOfCoefficients).size != numFeatures) {
269-
logError("The size of coefficients bound mismatched with number of features: " +
270-
s"lowerBoundOfCoefficients size = ${getLowerBoundOfCoefficients.size}, " +
271-
s"upperBoundOfCoefficients size = ${getUpperBoundOfCoefficients.size}, " +
272-
s"number of features = $numFeatures.")
273-
}
274-
275-
val validBound = $(lowerBoundOfCoefficients).toArray.zip($(upperBoundOfCoefficients).toArray)
276-
.forall(x => x._1 <= x._2)
277-
if (!validBound) {
278-
logError("LowerBoundOfCoefficients should always less than or equal to " +
279-
"upperBoundOfCoefficients, but found: " +
280-
s"lowerBoundOfCoefficients = $getLowerBoundOfCoefficients, " +
281-
s"upperBoundOfCoefficients = $getUpperBoundOfCoefficients.")
282-
}
293+
assertBoundConstrainedOptimizationParamsValid(numFeatures)
283294
}
284295

285296
val instances: RDD[Instance] = dataset.select(
@@ -410,10 +421,18 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
410421

411422
val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
412423
if (usingBoundConstrainedOptimization) {
413-
val lowerBound = BDV[Double]($(lowerBoundOfCoefficients).toArray.zip(featuresStd)
414-
.map{ case (lb, xStd) => lb * xStd / yStd })
415-
val upperBound = BDV[Double]($(upperBoundOfCoefficients).toArray.zip(featuresStd)
416-
.map{ case (ub, xStd) => ub * xStd / yStd })
424+
val lowerBound = if (isSet(lowerBoundOfCoefficients)) {
425+
BDV[Double]($(lowerBoundOfCoefficients).toArray.zip(featuresStd)
426+
.map{ case (lb, xStd) => lb * xStd / yStd })
427+
} else {
428+
BDV[Double](Array.fill(numFeatures)(Double.NegativeInfinity))
429+
}
430+
val upperBound = if (isSet(upperBoundOfCoefficients)) {
431+
BDV[Double]($(upperBoundOfCoefficients).toArray.zip(featuresStd)
432+
.map{ case (ub, xStd) => ub * xStd / yStd })
433+
} else {
434+
BDV[Double](Array.fill(numFeatures)(Double.PositiveInfinity))
435+
}
417436
initialValues = lowerBound.toArray.zip(upperBound.toArray).map { case (lb, ub) =>
418437
if (lb.isInfinity && ub.isInfinity) {
419438
0.0

mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,41 @@ class LinearRegressionSuite
167167
assert(model.numFeatures === numFeatures)
168168
}
169169

170+
test("linear regression: illegal params") {
171+
val lowerBoundOfCoefficients = Vectors.dense(Array(1.0, 1.0))
172+
val upperBoundOfCoefficients1 = Vectors.dense(Array(-2.0, 2.0))
173+
val upperBoundOfCoefficients2 = Vectors.dense(Array(2.0))
174+
175+
val lir = new LinearRegression().setLowerBoundOfCoefficients(lowerBoundOfCoefficients)
176+
177+
// Work well when only set bound in one side.
178+
lir.fit(datasetWithDenseFeature)
179+
180+
withClue("bound constrained optimization does not support normal solver.") {
181+
intercept[IllegalArgumentException] {
182+
lir.setSolver("normal").fit(datasetWithDenseFeature)
183+
}
184+
}
185+
186+
withClue("bound constrained optimization only supports L2 regularization") {
187+
intercept[IllegalArgumentException] {
188+
lir.setElasticNetParam(1.0).fit(datasetWithDenseFeature)
189+
}
190+
}
191+
192+
withClue("lowerBoundOfCoefficients should less than or equal to upperBoundOfCoefficients") {
193+
intercept[IllegalArgumentException] {
194+
lir.setUpperBoundOfCoefficients(upperBoundOfCoefficients1).fit(datasetWithDenseFeature)
195+
}
196+
}
197+
198+
withClue("the size of coefficients bound mismatched with number of features") {
199+
intercept[IllegalArgumentException] {
200+
lir.setUpperBoundOfCoefficients(upperBoundOfCoefficients2).fit(datasetWithDenseFeature)
201+
}
202+
}
203+
}
204+
170205
test("linear regression handles singular matrices") {
171206
// check for both constant columns with intercept (zero std) and collinear
172207
val singularDataConstantColumn = sc.parallelize(Seq(

0 commit comments

Comments
 (0)