From bd9663cc078289b2676c6bbf38a5e973bd5b82cc Mon Sep 17 00:00:00 2001 From: Omede Firouz Date: Tue, 31 Mar 2015 11:19:59 -0700 Subject: [PATCH 1/6] [MLLIB] Add fit intercept api to ml logisticregression I have the fit intercept enabled by default for logistic regression, I wonder what others think here. I understand that it enables allocation by default which is undesirable, but one needs to have a very strong reason for not having an intercept term enabled so it is the safer default from a statistical sense. Explicitly modeling the intercept by adding a column of all 1s does not work. I believe the reason is that since the API for LogisticRegressionWithLBFGS forces column normalization, and a column of all 1s has 0 variance so dividing by 0 kills it. --- .../spark/ml/classification/LogisticRegression.scala | 7 ++++++- .../org/apache/spark/ml/param/sharedParams.scala | 11 +++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 21f61d80dd95..43b95199daea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -31,7 +31,7 @@ import org.apache.spark.storage.StorageLevel * Params for logistic regression. */ private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams - with HasRegParam with HasMaxIter with HasThreshold + with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold /** @@ -46,6 +46,7 @@ class LogisticRegression with LogisticRegressionParams { setRegParam(0.1) + setFitIntercept(true) setMaxIter(100) setThreshold(0.5) @@ -55,6 +56,9 @@ class LogisticRegression /** @group setParam */ def setMaxIter(value: Int): this.type = set(maxIter, value) + /** @group setParam */ + def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) + /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) @@ -71,6 +75,7 @@ class LogisticRegression lr.optimizer .setRegParam(paramMap(regParam)) .setNumIterations(paramMap(maxIter)) + .addIntercept(paramMap(fitIntercept)) val oldModel = lr.run(oldDataset) val lrm = new LogisticRegressionModel(this, paramMap, oldModel.weights, oldModel.intercept) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala index 5d660d1e151a..ffe101107cf6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala @@ -106,6 +106,17 @@ private[ml] trait HasProbabilityCol extends Params { def getProbabilityCol: String = get(probabilityCol) } +private[ml] trait HasFitIntercept extends Params { + /** + * param for fitting the intercept term + * @group param + */ + val fitIntercept: BooleanParam = new BooleanParam(this, "fitIntercept", "fits the intercept term or not") + + /** @group getParam */ + def getFitIntercept: Boolean = get(fitIntercept) +} + private[ml] trait HasThreshold extends Params { /** * param for threshold in (binary) prediction From 329c1e20980036c26b16b0e220dd5156740847f7 Mon Sep 17 00:00:00 2001 From: Omede Firouz Date: Fri, 3 Apr 2015 13:43:22 -0700 Subject: [PATCH 2/6] [MLLIB] Add fit intercept term Added unit tests and changed docs in line with PR comments --- .../scala/org/apache/spark/ml/param/sharedParams.scala | 3 ++- .../ml/classification/LogisticRegressionSuite.scala | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala index ffe101107cf6..b194be4847aa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala @@ -111,7 +111,8 @@ private[ml] trait HasFitIntercept extends Params { * param for fitting the intercept term * @group param */ - val fitIntercept: BooleanParam = new BooleanParam(this, "fitIntercept", "fits the intercept term or not") + val fitIntercept: BooleanParam = + new BooleanParam(this, "fitIntercept", "indicates whether to fit an intercept term") /** @group getParam */ def getFitIntercept: Boolean = get(fitIntercept) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index b3d1bfcfbee0..154a75a738c1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -46,6 +46,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(lr.getPredictionCol == "prediction") assert(lr.getRawPredictionCol == "rawPrediction") assert(lr.getProbabilityCol == "probability") + assert(lr.getFitIntercept == true) val model = lr.fit(dataset) model.transform(dataset) .select("label", "probability", "prediction", "rawPrediction") @@ -55,6 +56,14 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(model.getPredictionCol == "prediction") assert(model.getRawPredictionCol == "rawPrediction") assert(model.getProbabilityCol == "probability") + assert(model.intercept != 0.0) + } + + test("logistic regression doesn't fit intercept when fitIntercept is off") { + val lr = new LogisticRegression + lr.setFitIntercept(false) + val model = lr.fit(dataset) + assert(model.intercept == 0.0) } test("logistic regression with setters") { From 2257fcaad897faee76f4034445591a0b19c30ab1 Mon Sep 17 00:00:00 2001 From: Omede Firouz Date: Fri, 3 Apr 2015 14:41:55 -0700 Subject: [PATCH 3/6] [MLLIB] Add fitIntercept param to logistic regression Made the trait default true Changed float comparisons to === in unit tests --- .../apache/spark/ml/classification/LogisticRegression.scala | 1 - .../main/scala/org/apache/spark/ml/param/sharedParams.scala | 2 +- .../spark/ml/classification/LogisticRegressionSuite.scala | 4 ++-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 43b95199daea..ca42b94705db 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -46,7 +46,6 @@ class LogisticRegression with LogisticRegressionParams { setRegParam(0.1) - setFitIntercept(true) setMaxIter(100) setThreshold(0.5) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala index b194be4847aa..c681908a166d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala @@ -112,7 +112,7 @@ private[ml] trait HasFitIntercept extends Params { * @group param */ val fitIntercept: BooleanParam = - new BooleanParam(this, "fitIntercept", "indicates whether to fit an intercept term") + new BooleanParam(this, "fitIntercept", "indicates whether to fit an intercept term", Some(true)) /** @group getParam */ def getFitIntercept: Boolean = get(fitIntercept) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 154a75a738c1..35d8c2e16c6c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -56,14 +56,14 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(model.getPredictionCol == "prediction") assert(model.getRawPredictionCol == "rawPrediction") assert(model.getProbabilityCol == "probability") - assert(model.intercept != 0.0) + assert(model.intercept !== 0.0) } test("logistic regression doesn't fit intercept when fitIntercept is off") { val lr = new LogisticRegression lr.setFitIntercept(false) val model = lr.fit(dataset) - assert(model.intercept == 0.0) + assert(model.intercept === 0.0) } test("logistic regression with setters") { From 9963509a1ebc559709d2cb48d8b7269a27f11553 Mon Sep 17 00:00:00 2001 From: Omede Firouz Date: Fri, 3 Apr 2015 14:45:54 -0700 Subject: [PATCH 4/6] [MLLIB] Add fitIntercept to LogisticRegression Forgot to update doc --- .../src/main/scala/org/apache/spark/ml/param/sharedParams.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala index c681908a166d..0739fdbfcbaa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala @@ -108,7 +108,7 @@ private[ml] trait HasProbabilityCol extends Params { private[ml] trait HasFitIntercept extends Params { /** - * param for fitting the intercept term + * param for fitting the intercept term, defaults to true * @group param */ val fitIntercept: BooleanParam = From 1d6bd6f8457cd8081c6d3e7632621e18e8372cec Mon Sep 17 00:00:00 2001 From: Omede Firouz Date: Fri, 3 Apr 2015 16:36:04 -0700 Subject: [PATCH 5/6] [SPARK-6705][MLLIB] Add a fit intercept term to ML LogisticRegression From 9f1286bb3acbc4b3149b42cde2148568a6e300be Mon Sep 17 00:00:00 2001 From: Omede Firouz Date: Sat, 4 Apr 2015 15:25:25 -0700 Subject: [PATCH 6/6] [SPARK-6705][MLLIB] Add fitInterceptTerm to LogisticRegression Whoops, add this to the logisticRegression and not the optimizer --- .../apache/spark/ml/classification/LogisticRegression.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index ca42b94705db..ab58fd3d5331 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -70,11 +70,11 @@ class LogisticRegression } // Train model - val lr = new LogisticRegressionWithLBFGS + val lr = new LogisticRegressionWithLBFGS() + .setIntercept(paramMap(fitIntercept)) lr.optimizer .setRegParam(paramMap(regParam)) .setNumIterations(paramMap(maxIter)) - .addIntercept(paramMap(fitIntercept)) val oldModel = lr.run(oldDataset) val lrm = new LogisticRegressionModel(this, paramMap, oldModel.weights, oldModel.intercept)