diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 486043e8d9741..40a3a26114da4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -247,14 +247,35 @@ class LogisticRegression @Since("1.2.0") ( @Since("1.5.0") override def getThresholds: Array[Double] = super.getThresholds - override protected def train(dataset: DataFrame): LogisticRegressionModel = { - // Extract columns from data. If dataset is persisted, do not persist oldDataset. + private var optInitialWeights: Option[Vector] = None + /** @group setParam */ + private[spark] def setInitialWeights(value: Vector): this.type = { + this.optInitialWeights = Some(value) + this + } + + /** Validate the initial weights, return an Option, if not the expected size return None + * and log a warning. + */ + private def validateWeights(vectorOpt: Option[Vector], numFeatures: Int): Option[Vector] = { + vectorOpt.flatMap(vec => + if (vec.size == numFeatures) { + Some(vec) + } else { + logWarning( + s"""Initial weights provided (${vec})did not match the expected size ${numFeatures}""") + None + }) + } + + override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = { val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) - val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { + val instances = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) @@ -322,10 +343,12 @@ class LogisticRegression @Since("1.2.0") ( new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol)) } - val initialCoefficientsWithIntercept = - Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures) + val numFeaturesWithIntercept = if ($(fitIntercept)) numFeatures + 1 else numFeatures + val userSuppliedWeights = validateWeights(optInitialWeights, numFeaturesWithIntercept) + val initialCoefficientsWithIntercept = userSuppliedWeights.getOrElse( + Vectors.zeros(numFeaturesWithIntercept)) - if ($(fitIntercept)) { + if ($(fitIntercept) && !userSuppliedWeights.isDefined) { /* For binary logistic regression, when we initialize the coefficients as zeros, it will converge faster if we initialize the intercept such that diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 2d52abc122bf2..f41b7c6318c86 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -21,13 +21,15 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.mllib.classification.impl.GLMClassificationModel import org.apache.spark.mllib.linalg.BLAS.dot -import org.apache.spark.mllib.linalg.{DenseVector, Vector} +import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors} import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader} +import org.apache.spark.mllib.util.MLUtils.appendBias import org.apache.spark.rdd.RDD - +import org.apache.spark.sql.SQLContext +import org.apache.spark.storage.StorageLevel /** * Classification model trained using Multinomial/Binary Logistic Regression. @@ -332,6 +334,13 @@ object LogisticRegressionWithSGD { * Limited-memory BFGS. Standard feature scaling and L2 regularization are used by default. * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1} * for k classes multi-label classification problem. + * + * Earlier implementations of LogisticRegressionWithLBFGS applies a regularization + * penalty to all elements including the intercept. If this is called with one of + * standard updaters (L1Updater, or SquaredL2Updater) this is translated + * into a call to ml.LogisticRegression, otherwise this will use the existing mllib + * GeneralizedLinearAlgorithm trainer, resulting in a regularization penalty to the + * intercept. */ @Since("1.1.0") class LogisticRegressionWithLBFGS @@ -374,4 +383,83 @@ class LogisticRegressionWithLBFGS new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1) } } + + + /** + * Run the algorithm with the configured parameters on an input RDD + * of LabeledPoint entries starting from the initial weights provided. + * If a known updater is used calls the ml implementation, to avoid + * applying a regularization penalty to the intercept, otherwise + * defaults to the mllib implementation. If more than two classes + * or feature scaling is disabled, always uses mllib implementation. + * If using ml implementation, uses ml code to generate initial weights. + */ + override def run(input: RDD[LabeledPoint]): LogisticRegressionModel = { + run(input, generateInitialWeights(input), false) + } + + /** + * Run the algorithm with the configured parameters on an input RDD + * of LabeledPoint entries starting from the initial weights provided. + * If a known updater is used calls the ml implementation, to avoid + * applying a regularization penalty to the intercept, otherwise + * defaults to the mllib implementation. If more than two classes + * or feature scaling is disabled, always uses mllib implementation. + * Uses user provided weights. + */ + override def run(input: RDD[LabeledPoint], initialWeights: Vector): LogisticRegressionModel = { + run(input, initialWeights, true) + } + + private def run(input: RDD[LabeledPoint], initialWeights: Vector, userSuppliedWeights: Boolean): + LogisticRegressionModel = { + // ml's Logisitic regression only supports binary classifcation currently. + if (numOfLinearPredictor == 1 && useFeatureScaling) { + def runWithMlLogisitcRegression(elasticNetParam: Double) = { + // Prepare the ml LogisticRegression based on our settings + val lr = new org.apache.spark.ml.classification.LogisticRegression() + lr.setRegParam(optimizer.getRegParam()) + lr.setElasticNetParam(elasticNetParam) + if (userSuppliedWeights) { + val initialWeightsWithIntercept = if (addIntercept) { + appendBias(initialWeights) + } else { + initialWeights + } + lr.setInitialWeights(initialWeightsWithIntercept) + } + lr.setFitIntercept(addIntercept) + lr.setMaxIter(optimizer.getNumIterations()) + lr.setTol(optimizer.getConvergenceTol()) + // Convert our input into a DataFrame + val sqlContext = new SQLContext(input.context) + import sqlContext.implicits._ + val df = input.toDF() + // Determine if we should cache the DF + val handlePersistence = input.getStorageLevel == StorageLevel.NONE + if (handlePersistence) { + df.persist(StorageLevel.MEMORY_AND_DISK) + } + // Train our model + val mlLogisticRegresionModel = lr.train(df) + // unpersist if we persisted + if (handlePersistence) { + df.unpersist() + } + // convert the model + val weights = mlLogisticRegresionModel.weights match { + case x: DenseVector => x + case y: Vector => Vectors.dense(y.toArray) + } + createModel(weights, mlLogisticRegresionModel.intercept) + } + optimizer.getUpdater() match { + case x: SquaredL2Updater => runWithMlLogisitcRegression(1.0) + case x: L1Updater => runWithMlLogisitcRegression(0.0) + case _ => super.run(input, initialWeights) + } + } else { + super.run(input, initialWeights) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index efedc112d380e..a5bd77e6bee91 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -69,6 +69,13 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) this } + /* + * Get the convergence tolerance of iterations. + */ + private[mllib] def getConvergenceTol(): Double = { + this.convergenceTol + } + /** * Set the maximal number of iterations for L-BFGS. Default 100. * @deprecated use [[LBFGS#setNumIterations]] instead @@ -86,6 +93,13 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) this } + /** + * Get the maximum number of iterations for L-BFGS. Defaults to 100. + */ + private[mllib] def getNumIterations(): Int = { + this.maxNumIterations + } + /** * Set the regularization parameter. Default 0.0. */ @@ -94,6 +108,13 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) this } + /** + * Get the regularization parameter. + */ + private[mllib] def getRegParam(): Double = { + this.regParam + } + /** * Set the gradient function (of the loss function of one single data example) * to be used for L-BFGS. @@ -113,6 +134,13 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater) this } + /** + * Returns the updater, limited to internal use. + */ + private[mllib] def getUpdater(): Updater = { + updater + } + override def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = { val (weights, _) = LBFGS.runLBFGS( data, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 8f657bfb9c730..9cec0d5310450 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -140,7 +140,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * translated back to resulting model weights, so it's transparent to users. * Note: This technique is used in both libsvm and glmnet packages. Default false. */ - private var useFeatureScaling = false + private[mllib] var useFeatureScaling = false /** * The dimension of training features. @@ -196,12 +196,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] } /** - * Run the algorithm with the configured parameters on an input - * RDD of LabeledPoint entries. - * + * Generate the initial weights when the user does not supply them */ - @Since("0.8.0") - def run(input: RDD[LabeledPoint]): M = { + protected def generateInitialWeights(input: RDD[LabeledPoint]): Vector = { if (numFeatures < 0) { numFeatures = input.map(_.features.size).first() } @@ -217,16 +214,23 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * TODO: See if we can deprecate `intercept` in `GeneralizedLinearModel`, and always * have the intercept as part of weights to have consistent design. */ - val initialWeights = { - if (numOfLinearPredictor == 1) { - Vectors.zeros(numFeatures) - } else if (addIntercept) { - Vectors.zeros((numFeatures + 1) * numOfLinearPredictor) - } else { - Vectors.zeros(numFeatures * numOfLinearPredictor) - } + if (numOfLinearPredictor == 1) { + Vectors.zeros(numFeatures) + } else if (addIntercept) { + Vectors.zeros((numFeatures + 1) * numOfLinearPredictor) + } else { + Vectors.zeros(numFeatures * numOfLinearPredictor) } - run(input, initialWeights) + } + + /** + * Run the algorithm with the configured parameters on an input + * RDD of LabeledPoint entries. + * + */ + @Since("0.8.0") + def run(input: RDD[LabeledPoint]): M = { + run(input, generateInitialWeights(input)) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 5ea71c5317b7a..e60197b34ec40 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -168,7 +168,7 @@ private class MockLogisticRegression(uid: String) extends LogisticRegression(uid setMaxIter(1) - override protected def train(dataset: DataFrame): LogisticRegressionModel = { + override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = { val labelSchema = dataset.schema($(labelCol)) // check for label attribute propagation. assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 8d14bb6572155..864c19103033c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.mllib.optimization.{SquaredL2Updater, L1Updater, Updater, LBFGS, LogisticGradient} import org.apache.spark.util.Utils @@ -215,6 +216,11 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w // Test if we can correctly learn A, B where Y = logistic(A + B*X) test("logistic regression with LBFGS") { + val updaters: List[Updater] = List(new SquaredL2Updater(), new L1Updater()) + updaters.foreach(testLBFGS) + } + + private def testLBFGS(myUpdater: Updater): Unit = { val nPoints = 10000 val A = 2.0 val B = -1.5 @@ -223,7 +229,15 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w val testRDD = sc.parallelize(testData, 2) testRDD.cache() - val lr = new LogisticRegressionWithLBFGS().setIntercept(true) + + // Override the updater + class LogisticRegressionWithLBFGSCustomUpdater + extends LogisticRegressionWithLBFGS { + override val optimizer = + new LBFGS(new LogisticGradient, myUpdater) + } + + val lr = new LogisticRegressionWithLBFGSCustomUpdater().setIntercept(true) val model = lr.run(testRDD)