-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-18518][ML] HasSolver supports override #16028
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
b907314
39ea8e1
6bb1daf
ebfa9c0
d15ea65
51cb43b
2e280d5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,13 +27,16 @@ import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer} | |
| import org.apache.spark.ml.feature.LabeledPoint | ||
| import org.apache.spark.ml.linalg.{Vector, Vectors} | ||
| import org.apache.spark.ml.param._ | ||
| import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasStepSize, HasTol} | ||
| import org.apache.spark.ml.param.shared._ | ||
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.sql.Dataset | ||
|
|
||
| /** Params for Multilayer Perceptron. */ | ||
| private[classification] trait MultilayerPerceptronParams extends PredictorParams | ||
| with HasSeed with HasMaxIter with HasTol with HasStepSize { | ||
| with HasSeed with HasMaxIter with HasTol with HasStepSize with HasSolver { | ||
|
|
||
| import MultilayerPerceptronClassifier._ | ||
|
|
||
| /** | ||
| * Layer sizes including input size and output size. | ||
| * | ||
|
|
@@ -75,17 +78,13 @@ private[classification] trait MultilayerPerceptronParams extends PredictorParams | |
| * Supported options: "gd" (minibatch gradient descent) or "l-bfgs". | ||
| * Default: "l-bfgs" | ||
| * | ||
| * @group expertParam | ||
| * @group param | ||
| */ | ||
| @Since("2.0.0") | ||
| final val solver: Param[String] = new Param[String](this, "solver", | ||
| final override val solver: Param[String] = new Param[String](this, "solver", | ||
| "The solver algorithm for optimization. Supported options: " + | ||
| s"${MultilayerPerceptronClassifier.supportedSolvers.mkString(", ")}. (Default l-bfgs)", | ||
| ParamValidators.inArray[String](MultilayerPerceptronClassifier.supportedSolvers)) | ||
|
|
||
| /** @group expertGetParam */ | ||
| @Since("2.0.0") | ||
| final def getSolver: String = $(solver) | ||
| s"${supportedSolvers.mkString(", ")}. (Default l-bfgs)", | ||
| ParamValidators.inArray[String](supportedSolvers)) | ||
|
|
||
| /** | ||
| * The initial weights of the model. | ||
|
|
@@ -101,7 +100,7 @@ private[classification] trait MultilayerPerceptronParams extends PredictorParams | |
| final def getInitialWeights: Vector = $(initialWeights) | ||
|
|
||
| setDefault(maxIter -> 100, tol -> 1e-6, blockSize -> 128, | ||
| solver -> MultilayerPerceptronClassifier.LBFGS, stepSize -> 0.03) | ||
| solver -> LBFGS, stepSize -> 0.03) | ||
|
||
| } | ||
|
|
||
| /** Label to vector converter. */ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -143,7 +143,18 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam | |
| isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty | ||
| } | ||
|
|
||
| import GeneralizedLinearRegression._ | ||
| /** | ||
| * The solver algorithm for optimization. | ||
| * Supported options: "irls" (iteratively reweighted least squares). | ||
| * Default: "irls" | ||
| * | ||
| * @group param | ||
| */ | ||
| @Since("2.3.0") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 2.3.0 -> 2.0.0, please fix it in #17995 . |
||
| final override val solver: Param[String] = new Param[String](this, "solver", | ||
| "The solver algorithm for optimization. Supported options: " + | ||
| s"${supportedSolvers.mkString(", ")}. (Default irls)", | ||
| ParamValidators.inArray[String](supportedSolvers)) | ||
|
|
||
| @Since("2.0.0") | ||
| override def validateAndTransformSchema( | ||
|
|
@@ -314,7 +325,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val | |
| */ | ||
| @Since("2.0.0") | ||
| def setSolver(value: String): this.type = set(solver, value) | ||
| setDefault(solver -> "irls") | ||
| setDefault(solver -> IRLS) | ||
|
|
||
| /** | ||
| * Sets the link prediction (linear predictor) column name. | ||
|
|
@@ -400,6 +411,12 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine | |
| Gamma -> Inverse, Gamma -> Identity, Gamma -> Log | ||
| ) | ||
|
|
||
| /** String name for "irls" (iteratively reweighted least squares) solver. */ | ||
| private[regression] val IRLS = "irls" | ||
|
|
||
| /** Set of solvers that GeneralizedLinearRegression supports. */ | ||
| private[regression] val supportedSolvers = Array(IRLS) | ||
|
|
||
| /** Set of family names that GeneralizedLinearRegression supports. */ | ||
| private[regression] lazy val supportedFamilyNames = | ||
| supportedFamilyAndLinkPairs.map(_._1.name).toArray :+ "tweedie" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,7 +34,7 @@ import org.apache.spark.ml.optim.WeightedLeastSquares | |
| import org.apache.spark.ml.PredictorParams | ||
| import org.apache.spark.ml.optim.aggregator.LeastSquaresAggregator | ||
| import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction} | ||
| import org.apache.spark.ml.param.ParamMap | ||
| import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} | ||
| import org.apache.spark.ml.param.shared._ | ||
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.mllib.evaluation.RegressionMetrics | ||
|
|
@@ -53,7 +53,23 @@ import org.apache.spark.storage.StorageLevel | |
| private[regression] trait LinearRegressionParams extends PredictorParams | ||
| with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol | ||
| with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver | ||
| with HasAggregationDepth | ||
| with HasAggregationDepth { | ||
|
|
||
| import LinearRegression._ | ||
|
|
||
| /** | ||
| * The solver algorithm for optimization. | ||
| * Supported options: "l-bfgs", "normal" and "auto". | ||
| * Default: "auto" | ||
| * | ||
| * @group expertParam | ||
| */ | ||
| @Since("2.3.0") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 2.3.0 -> 1.6.0 |
||
| final override val solver: Param[String] = new Param[String](this, "solver", | ||
| "The solver algorithm for optimization. Supported options: " + | ||
| s"${supportedSolvers.mkString(", ")}. (Default auto)", | ||
| ParamValidators.inArray[String](supportedSolvers)) | ||
| } | ||
|
|
||
| /** | ||
| * Linear regression. | ||
|
|
@@ -78,6 +94,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String | |
| extends Regressor[Vector, LinearRegression, LinearRegressionModel] | ||
| with LinearRegressionParams with DefaultParamsWritable with Logging { | ||
|
|
||
| import LinearRegression._ | ||
|
|
||
| @Since("1.4.0") | ||
| def this() = this(Identifiable.randomUID("linReg")) | ||
|
|
||
|
|
@@ -175,12 +193,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String | |
| * @group setParam | ||
| */ | ||
| @Since("1.6.0") | ||
| def setSolver(value: String): this.type = { | ||
| require(Set("auto", "l-bfgs", "normal").contains(value), | ||
| s"Solver $value was not supported. Supported options: auto, l-bfgs, normal") | ||
| set(solver, value) | ||
| } | ||
| setDefault(solver -> "auto") | ||
| def setSolver(value: String): this.type = set(solver, value) | ||
| setDefault(solver -> AUTO) | ||
|
|
||
| /** | ||
| * Suggested depth for treeAggregate (greater than or equal to 2). | ||
|
|
@@ -210,8 +224,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String | |
| elasticNetParam, fitIntercept, maxIter, regParam, standardization, aggregationDepth) | ||
| instr.logNumFeatures(numFeatures) | ||
|
|
||
| if (($(solver) == "auto" && | ||
| numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") { | ||
| if (($(solver) == AUTO && | ||
| numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == NORMAL) { | ||
| // For low dimensional data, WeightedLeastSquares is more efficient since the | ||
| // training algorithm only requires one pass through the data. (SPARK-10668) | ||
|
|
||
|
|
@@ -444,6 +458,18 @@ object LinearRegression extends DefaultParamsReadable[LinearRegression] { | |
| */ | ||
| @Since("2.1.0") | ||
| val MAX_FEATURES_FOR_NORMAL_SOLVER: Int = WeightedLeastSquares.MAX_NUM_FEATURES | ||
|
|
||
| /** String name for "auto". */ | ||
| private[regression] val AUTO = "auto" | ||
|
||
|
|
||
| /** String name for "normal". */ | ||
| private[regression] val NORMAL = "normal" | ||
|
|
||
| /** String name for "l-bfgs". */ | ||
| private[regression] val LBFGS = "l-bfgs" | ||
|
|
||
| /** Set of solvers that LinearRegression supports. */ | ||
| private[regression] val supportedSolvers = Array(AUTO, NORMAL, LBFGS) | ||
| } | ||
|
|
||
| /** | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here should be
expertParam.