
Commit 723853e

[SPARK-7648] [MLLIB] Add weights and intercept to GLM wrappers in spark.ml
Otherwise, users can only use `transform` on the models. brkyvz

Author: Xiangrui Meng <[email protected]>

Closes apache#6156 from mengxr/SPARK-7647 and squashes the following commits:

1ae3d2d [Xiangrui Meng] add weights and intercept to LogisticRegression in Python
f49eb46 [Xiangrui Meng] add weights and intercept to LinearRegressionModel
1 parent b208f99 commit 723853e
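In practice, the new `weights` and `intercept` properties let users inspect a fitted model's parameters directly instead of being limited to `transform`. A minimal usage sketch mirroring the doctests added below; it assumes an active SparkContext `sc` and SQLContext (as in the doctest environment), and the toy training data and hyperparameters are illustrative, not taken from the commit:

    from pyspark.sql import Row
    from pyspark.mllib.linalg import Vectors
    from pyspark.ml.classification import LogisticRegression

    # Tiny single-feature training set (illustrative only).
    df = sc.parallelize([
        Row(label=1.0, features=Vectors.dense(1.0)),
        Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF()

    model = LogisticRegression(maxIter=5, regParam=0.01).fit(df)

    # Previously only transform() was usable on the fitted model:
    test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
    print(model.transform(test0).head().prediction)  # 0.0

    # With this change the learned parameters are exposed as well:
    print(model.weights)    # e.g. DenseVector([5.5...])
    print(model.intercept)  # e.g. -2.68...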

3 files changed, 43 insertions(+), 1 deletion(-)


python/pyspark/ml/classification.py

Lines changed: 18 additions & 0 deletions
@@ -43,6 +43,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
     >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
     >>> model.transform(test0).head().prediction
     0.0
+    >>> model.weights
+    DenseVector([5.5...])
+    >>> model.intercept
+    -2.68...
     >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF()
     >>> model.transform(test1).head().prediction
     1.0
@@ -148,6 +152,20 @@ class LogisticRegressionModel(JavaModel):
     Model fitted by LogisticRegression.
     """
 
+    @property
+    def weights(self):
+        """
+        Model weights.
+        """
+        return self._call_java("weights")
+
+    @property
+    def intercept(self):
+        """
+        Model intercept.
+        """
+        return self._call_java("intercept")
+
 
 class TreeClassifierParams(object):
     """

python/pyspark/ml/regression.py

Lines changed: 18 additions & 0 deletions
@@ -51,6 +51,10 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
     >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> model.transform(test0).head().prediction
     -1.0
+    >>> model.weights
+    DenseVector([1.0])
+    >>> model.intercept
+    0.0
     >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
     >>> model.transform(test1).head().prediction
     1.0
@@ -117,6 +121,20 @@ class LinearRegressionModel(JavaModel):
     Model fitted by LinearRegression.
     """
 
+    @property
+    def weights(self):
+        """
+        Model weights.
+        """
+        return self._call_java("weights")
+
+    @property
+    def intercept(self):
+        """
+        Model intercept.
+        """
+        return self._call_java("intercept")
+
 
 class TreeRegressorParams(object):
     """

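The regression side follows the same pattern. A minimal sketch, again with illustrative data and settings, assuming an active SQLContext `sqlContext`:

    from pyspark.mllib.linalg import Vectors
    from pyspark.ml.regression import LinearRegression

    # Two points on the line y = x (illustrative only).
    df = sqlContext.createDataFrame([
        (1.0, Vectors.dense(1.0)),
        (0.0, Vectors.sparse(1, [], []))], ["label", "features"])

    model = LinearRegression(maxIter=5, regParam=0.0).fit(df)

    test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
    print(model.transform(test0).head().prediction)  # -1.0
    print(model.weights)    # e.g. DenseVector([1.0])
    print(model.intercept)  # e.g. 0.0
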
python/pyspark/ml/wrapper.py

Lines changed: 7 additions & 1 deletion
@@ -21,7 +21,7 @@
 from pyspark.sql import DataFrame
 from pyspark.ml.param import Params
 from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model
-from pyspark.mllib.common import inherit_doc
+from pyspark.mllib.common import inherit_doc, _java2py, _py2java
 
 
 def _jvm():
@@ -149,6 +149,12 @@ def __init__(self, java_model):
     def _java_obj(self):
         return self._java_model
 
+    def _call_java(self, name, *args):
+        m = getattr(self._java_model, name)
+        sc = SparkContext._active_spark_context
+        java_args = [_py2java(sc, arg) for arg in args]
+        return _java2py(sc, m(*java_args))
+
 
 @inherit_doc
 class JavaEvaluator(Evaluator, JavaWrapper):
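Both model classes delegate to the `_call_java` helper added here: it looks up the named method on the wrapped Java model, converts any positional arguments to their JVM counterparts with `_py2java`, invokes the method over Py4J, and converts the result back with `_java2py` (which is how a Scala vector surfaces in Python as a `DenseVector`). A minimal sketch of how a `JavaModel` subclass can use it; the class and the Java-side method names (`numClasses`, `evaluate`) are hypothetical and only for illustration:

    from pyspark.ml.wrapper import JavaModel

    class SomeJavaModel(JavaModel):
        @property
        def numClasses(self):
            # Zero-argument call; the JVM return value is converted by _java2py.
            return self._call_java("numClasses")

        def evaluate(self, dataset):
            # Positional arguments (e.g. a DataFrame) are converted with
            # _py2java before the Py4J invocation.
            return self._call_java("evaluate", dataset)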
