Skip to content

Commit e0b047e

Browse files
zhengruifengyanboliang
authored andcommitted
[SPARK-18518][ML] HasSolver supports override
## What changes were proposed in this pull request? 1, make param support non-final with `finalFields` option 2, generate `HasSolver` with `finalFields = false` 3, override `solver` in LiR, GLR, and make MLPC inherit `HasSolver` ## How was this patch tested? existing tests Author: Ruifeng Zheng <[email protected]> Author: Zheng RuiFeng <[email protected]> Closes #16028 from zhengruifeng/param_non_final.
1 parent 37ef32e commit e0b047e

File tree

7 files changed

+82
-46
lines changed

7 files changed

+82
-46
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,16 @@ import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer}
2727
import org.apache.spark.ml.feature.LabeledPoint
2828
import org.apache.spark.ml.linalg.{Vector, Vectors}
2929
import org.apache.spark.ml.param._
30-
import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasStepSize, HasTol}
30+
import org.apache.spark.ml.param.shared._
3131
import org.apache.spark.ml.util._
3232
import org.apache.spark.sql.Dataset
3333

3434
/** Params for Multilayer Perceptron. */
3535
private[classification] trait MultilayerPerceptronParams extends PredictorParams
36-
with HasSeed with HasMaxIter with HasTol with HasStepSize {
36+
with HasSeed with HasMaxIter with HasTol with HasStepSize with HasSolver {
37+
38+
import MultilayerPerceptronClassifier._
39+
3740
/**
3841
* Layer sizes including input size and output size.
3942
*
@@ -78,14 +81,10 @@ private[classification] trait MultilayerPerceptronParams extends PredictorParams
7881
* @group expertParam
7982
*/
8083
@Since("2.0.0")
81-
final val solver: Param[String] = new Param[String](this, "solver",
84+
final override val solver: Param[String] = new Param[String](this, "solver",
8285
"The solver algorithm for optimization. Supported options: " +
83-
s"${MultilayerPerceptronClassifier.supportedSolvers.mkString(", ")}. (Default l-bfgs)",
84-
ParamValidators.inArray[String](MultilayerPerceptronClassifier.supportedSolvers))
85-
86-
/** @group expertGetParam */
87-
@Since("2.0.0")
88-
final def getSolver: String = $(solver)
86+
s"${supportedSolvers.mkString(", ")}. (Default l-bfgs)",
87+
ParamValidators.inArray[String](supportedSolvers))
8988

9089
/**
9190
* The initial weights of the model.
@@ -101,7 +100,7 @@ private[classification] trait MultilayerPerceptronParams extends PredictorParams
101100
final def getInitialWeights: Vector = $(initialWeights)
102101

103102
setDefault(maxIter -> 100, tol -> 1e-6, blockSize -> 128,
104-
solver -> MultilayerPerceptronClassifier.LBFGS, stepSize -> 0.03)
103+
solver -> LBFGS, stepSize -> 0.03)
105104
}
106105

107106
/** Label to vector converter. */

mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ private[shared] object SharedParamsCodeGen {
8080
" 0)", isValid = "ParamValidators.gt(0)"),
8181
ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " +
8282
"all instance weights as 1.0"),
83-
ParamDesc[String]("solver", "the solver algorithm for optimization. If this is not set or " +
84-
"empty, default value is 'auto'", Some("\"auto\"")),
83+
ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false),
8584
ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"),
8685
isValid = "ParamValidators.gtEq(2)", isExpertParam = true))
8786

@@ -99,6 +98,7 @@ private[shared] object SharedParamsCodeGen {
9998
defaultValueStr: Option[String] = None,
10099
isValid: String = "",
101100
finalMethods: Boolean = true,
101+
finalFields: Boolean = true,
102102
isExpertParam: Boolean = false) {
103103

104104
require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.")
@@ -167,6 +167,11 @@ private[shared] object SharedParamsCodeGen {
167167
} else {
168168
"def"
169169
}
170+
val fieldStr = if (param.finalFields) {
171+
"final val"
172+
} else {
173+
"val"
174+
}
170175

171176
val htmlCompliantDoc = Utility.escape(doc)
172177

@@ -180,7 +185,7 @@ private[shared] object SharedParamsCodeGen {
180185
| * Param for $htmlCompliantDoc.
181186
| * @group ${groupStr(0)}
182187
| */
183-
| final val $name: $Param = new $Param(this, "$name", "$doc"$isValid)
188+
| $fieldStr $name: $Param = new $Param(this, "$name", "$doc"$isValid)
184189
|$setDefault
185190
| /** @group ${groupStr(1)} */
186191
| $methodStr get$Name: $T = $$($name)

mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -374,17 +374,15 @@ private[ml] trait HasWeightCol extends Params {
374374
}
375375

376376
/**
377-
* Trait for shared param solver (default: "auto").
377+
* Trait for shared param solver.
378378
*/
379379
private[ml] trait HasSolver extends Params {
380380

381381
/**
382-
* Param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.
382+
* Param for the solver algorithm for optimization.
383383
* @group param
384384
*/
385-
final val solver: Param[String] = new Param[String](this, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'")
386-
387-
setDefault(solver, "auto")
385+
val solver: Param[String] = new Param[String](this, "solver", "the solver algorithm for optimization")
388386

389387
/** @group getParam */
390388
final def getSolver: String = $(solver)

mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,18 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
164164
isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty
165165
}
166166

167-
import GeneralizedLinearRegression._
167+
/**
168+
* The solver algorithm for optimization.
169+
* Supported options: "irls" (iteratively reweighted least squares).
170+
* Default: "irls"
171+
*
172+
* @group param
173+
*/
174+
@Since("2.3.0")
175+
final override val solver: Param[String] = new Param[String](this, "solver",
176+
"The solver algorithm for optimization. Supported options: " +
177+
s"${supportedSolvers.mkString(", ")}. (Default irls)",
178+
ParamValidators.inArray[String](supportedSolvers))
168179

169180
@Since("2.0.0")
170181
override def validateAndTransformSchema(
@@ -350,7 +361,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
350361
*/
351362
@Since("2.0.0")
352363
def setSolver(value: String): this.type = set(solver, value)
353-
setDefault(solver -> "irls")
364+
setDefault(solver -> IRLS)
354365

355366
/**
356367
* Sets the link prediction (linear predictor) column name.
@@ -442,6 +453,12 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
442453
Gamma -> Inverse, Gamma -> Identity, Gamma -> Log
443454
)
444455

456+
/** String name for "irls" (iteratively reweighted least squares) solver. */
457+
private[regression] val IRLS = "irls"
458+
459+
/** Set of solvers that GeneralizedLinearRegression supports. */
460+
private[regression] val supportedSolvers = Array(IRLS)
461+
445462
/** Set of family names that GeneralizedLinearRegression supports. */
446463
private[regression] lazy val supportedFamilyNames =
447464
supportedFamilyAndLinkPairs.map(_._1.name).toArray :+ "tweedie"

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ import org.apache.spark.ml.optim.WeightedLeastSquares
3434
import org.apache.spark.ml.PredictorParams
3535
import org.apache.spark.ml.optim.aggregator.LeastSquaresAggregator
3636
import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
37-
import org.apache.spark.ml.param.ParamMap
37+
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
3838
import org.apache.spark.ml.param.shared._
3939
import org.apache.spark.ml.util._
4040
import org.apache.spark.mllib.evaluation.RegressionMetrics
@@ -53,7 +53,23 @@ import org.apache.spark.storage.StorageLevel
5353
private[regression] trait LinearRegressionParams extends PredictorParams
5454
with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
5555
with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver
56-
with HasAggregationDepth
56+
with HasAggregationDepth {
57+
58+
import LinearRegression._
59+
60+
/**
61+
* The solver algorithm for optimization.
62+
* Supported options: "l-bfgs", "normal" and "auto".
63+
* Default: "auto"
64+
*
65+
* @group param
66+
*/
67+
@Since("2.3.0")
68+
final override val solver: Param[String] = new Param[String](this, "solver",
69+
"The solver algorithm for optimization. Supported options: " +
70+
s"${supportedSolvers.mkString(", ")}. (Default auto)",
71+
ParamValidators.inArray[String](supportedSolvers))
72+
}
5773

5874
/**
5975
* Linear regression.
@@ -78,6 +94,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
7894
extends Regressor[Vector, LinearRegression, LinearRegressionModel]
7995
with LinearRegressionParams with DefaultParamsWritable with Logging {
8096

97+
import LinearRegression._
98+
8199
@Since("1.4.0")
82100
def this() = this(Identifiable.randomUID("linReg"))
83101

@@ -175,12 +193,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
175193
* @group setParam
176194
*/
177195
@Since("1.6.0")
178-
def setSolver(value: String): this.type = {
179-
require(Set("auto", "l-bfgs", "normal").contains(value),
180-
s"Solver $value was not supported. Supported options: auto, l-bfgs, normal")
181-
set(solver, value)
182-
}
183-
setDefault(solver -> "auto")
196+
def setSolver(value: String): this.type = set(solver, value)
197+
setDefault(solver -> AUTO)
184198

185199
/**
186200
* Suggested depth for treeAggregate (greater than or equal to 2).
@@ -210,8 +224,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
210224
elasticNetParam, fitIntercept, maxIter, regParam, standardization, aggregationDepth)
211225
instr.logNumFeatures(numFeatures)
212226

213-
if (($(solver) == "auto" &&
214-
numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") {
227+
if (($(solver) == AUTO &&
228+
numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == NORMAL) {
215229
// For low dimensional data, WeightedLeastSquares is more efficient since the
216230
// training algorithm only requires one pass through the data. (SPARK-10668)
217231

@@ -444,6 +458,18 @@ object LinearRegression extends DefaultParamsReadable[LinearRegression] {
444458
*/
445459
@Since("2.1.0")
446460
val MAX_FEATURES_FOR_NORMAL_SOLVER: Int = WeightedLeastSquares.MAX_NUM_FEATURES
461+
462+
/** String name for "auto". */
463+
private[regression] val AUTO = "auto"
464+
465+
/** String name for "normal". */
466+
private[regression] val NORMAL = "normal"
467+
468+
/** String name for "l-bfgs". */
469+
private[regression] val LBFGS = "l-bfgs"
470+
471+
/** Set of solvers that LinearRegression supports. */
472+
private[regression] val supportedSolvers = Array(AUTO, NORMAL, LBFGS)
447473
}
448474

449475
/**

python/pyspark/ml/classification.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1265,8 +1265,8 @@ def theta(self):
12651265

12661266
@inherit_doc
12671267
class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
1268-
HasMaxIter, HasTol, HasSeed, HasStepSize, JavaMLWritable,
1269-
JavaMLReadable):
1268+
HasMaxIter, HasTol, HasSeed, HasStepSize, HasSolver,
1269+
JavaMLWritable, JavaMLReadable):
12701270
"""
12711271
Classifier trainer based on the Multilayer Perceptron.
12721272
Each layer has sigmoid activation function, output layer has softmax.
@@ -1407,20 +1407,6 @@ def getStepSize(self):
14071407
"""
14081408
return self.getOrDefault(self.stepSize)
14091409

1410-
@since("2.0.0")
1411-
def setSolver(self, value):
1412-
"""
1413-
Sets the value of :py:attr:`solver`.
1414-
"""
1415-
return self._set(solver=value)
1416-
1417-
@since("2.0.0")
1418-
def getSolver(self):
1419-
"""
1420-
Gets the value of solver or its default value.
1421-
"""
1422-
return self.getOrDefault(self.solver)
1423-
14241410
@since("2.0.0")
14251411
def setInitialWeights(self, value):
14261412
"""

python/pyspark/ml/regression.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
9595
.. versionadded:: 1.4.0
9696
"""
9797

98+
solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
99+
"options: auto, normal, l-bfgs.", typeConverter=TypeConverters.toString)
100+
98101
@keyword_only
99102
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
100103
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
@@ -1371,6 +1374,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
13711374
linkPower = Param(Params._dummy(), "linkPower", "The index in the power link function. " +
13721375
"Only applicable to the Tweedie family.",
13731376
typeConverter=TypeConverters.toFloat)
1377+
solver = Param(Params._dummy(), "solver", "The solver algorithm for optimization. Supported " +
1378+
"options: irls.", typeConverter=TypeConverters.toString)
13741379

13751380
@keyword_only
13761381
def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",

0 commit comments

Comments
 (0)