Commit 544b4d0

[SPARK-21241][MLlib] Add setIntercept to StreamingLinearRegressionWithSGD in PySpark.
The StreamingLinearRegressionWithSGD class in PySpark is missing the setIntercept method, which makes it possible to turn the intercept on or off, so the Python API is not at parity with the Scala API. This patch adds setIntercept to StreamingLinearRegressionWithSGD; the new method calls setIntercept on the LinearRegressionModel class to turn the intercept on or off. Many thanks to Matthieu CANEILL for his help in solving the issue. The patch was tested by running all tests with ./dev/run-tests and by manual tests.
1 parent 9ce714d commit 544b4d0

1 file changed: +9 -0 lines changed


python/pyspark/mllib/regression.py

Lines changed: 9 additions & 0 deletions
@@ -199,6 +199,10 @@ def load(cls, sc, path):
         model = LinearRegressionModel(weights, intercept)
         return model

+    @since("2.3.0")
+    def setIntercept(self, intercept):
+        self._intercept = intercept
+

 # train_func should take two parameters, namely data and initial_weights, and
 # return the result of a call to the appropriate JVM stub.
@@ -795,6 +799,11 @@ def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, conver
         super(StreamingLinearRegressionWithSGD, self).__init__(
             model=self._model)

+    @since("2.3.0")
+    def setIntercept(self, intercept):
+        """Set if the algorithm should add an intercept"""
+        self._model.setIntercept(intercept)
+
     @since("1.5.0")
     def setInitialWeights(self, initialWeights):
         """
