Skip to content

Commit d138aa8

Browse files
Omede Firouz authored and jkbradley committed
[SPARK-6705][MLLIB] Add fit intercept api to ml logisticregression
I have fit intercept enabled by default for logistic regression; I wonder what others think here. I understand that this enables allocation by default, which is undesirable, but one needs a very strong reason not to fit an intercept term, so it is the safer default in a statistical sense. Explicitly modeling the intercept by adding a column of all 1s does not work. I believe the reason is that the API for LogisticRegressionWithLBFGS forces column normalization, and a column of all 1s has zero variance, so dividing by 0 kills it.

Author: Omede Firouz <[email protected]>

Closes apache#5301 from oefirouz/addIntercept and squashes the following commits:

9f1286b [Omede Firouz] [SPARK-6705][MLLIB] Add fitInterceptTerm to LogisticRegression
1d6bd6f [Omede Firouz] [SPARK-6705][MLLIB] Add a fit intercept term to ML LogisticRegression
9963509 [Omede Firouz] [MLLIB] Add fitIntercept to LogisticRegression
2257fca [Omede Firouz] [MLLIB] Add fitIntercept param to logistic regression
329c1e2 [Omede Firouz] [MLLIB] Add fit intercept term
bd9663c [Omede Firouz] [MLLIB] Add fit intercept api to ml logisticregression
1 parent c83e039 commit d138aa8

File tree

3 files changed

+27
-2
lines changed

3 files changed

+27
-2
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ import org.apache.spark.storage.StorageLevel
3131
* Params for logistic regression.
3232
*/
3333
private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
34-
with HasRegParam with HasMaxIter with HasThreshold
34+
with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold
3535

3636

3737
/**
@@ -55,6 +55,9 @@ class LogisticRegression
5555
/** @group setParam */
5656
def setMaxIter(value: Int): this.type = set(maxIter, value)
5757

58+
/** @group setParam */
59+
def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
60+
5861
/** @group setParam */
5962
def setThreshold(value: Double): this.type = set(threshold, value)
6063

@@ -67,7 +70,8 @@ class LogisticRegression
6770
}
6871

6972
// Train model
70-
val lr = new LogisticRegressionWithLBFGS
73+
val lr = new LogisticRegressionWithLBFGS()
74+
.setIntercept(paramMap(fitIntercept))
7175
lr.optimizer
7276
.setRegParam(paramMap(regParam))
7377
.setNumIterations(paramMap(maxIter))

mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -106,6 +106,18 @@ private[ml] trait HasProbabilityCol extends Params {
106106
def getProbabilityCol: String = get(probabilityCol)
107107
}
108108

109+
private[ml] trait HasFitIntercept extends Params {
110+
/**
111+
* param for fitting the intercept term, defaults to true
112+
* @group param
113+
*/
114+
val fitIntercept: BooleanParam =
115+
new BooleanParam(this, "fitIntercept", "indicates whether to fit an intercept term", Some(true))
116+
117+
/** @group getParam */
118+
def getFitIntercept: Boolean = get(fitIntercept)
119+
}
120+
109121
private[ml] trait HasThreshold extends Params {
110122
/**
111123
* param for threshold in (binary) prediction

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
4646
assert(lr.getPredictionCol == "prediction")
4747
assert(lr.getRawPredictionCol == "rawPrediction")
4848
assert(lr.getProbabilityCol == "probability")
49+
assert(lr.getFitIntercept == true)
4950
val model = lr.fit(dataset)
5051
model.transform(dataset)
5152
.select("label", "probability", "prediction", "rawPrediction")
@@ -55,6 +56,14 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
5556
assert(model.getPredictionCol == "prediction")
5657
assert(model.getRawPredictionCol == "rawPrediction")
5758
assert(model.getProbabilityCol == "probability")
59+
assert(model.intercept !== 0.0)
60+
}
61+
62+
test("logistic regression doesn't fit intercept when fitIntercept is off") {
63+
val lr = new LogisticRegression
64+
lr.setFitIntercept(false)
65+
val model = lr.fit(dataset)
66+
assert(model.intercept === 0.0)
5867
}
5968

6069
test("logistic regression with setters") {

0 commit comments

Comments
 (0)