@@ -43,6 +43,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
4343 >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
4444 >>> model.transform(test0).head().prediction
4545 0.0
46+ >>> model.weights
47+ DenseVector([5.5...])
48+ >>> model.intercept
49+ -2.68...
4650 >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF()
4751 >>> model.transform(test1).head().prediction
4852 1.0
@@ -67,7 +71,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
6771 threshold = 0.5 , probabilityCol = "probability" ):
6872 """
6973 __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
70- maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
74+ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
7175 threshold=0.5, probabilityCol="probability")
7276 """
7377 super (LogisticRegression , self ).__init__ ()
@@ -92,8 +96,8 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
9296 maxIter = 100 , regParam = 0.1 , elasticNetParam = 0.0 , tol = 1e-6 , fitIntercept = True ,
9397 threshold = 0.5 , probabilityCol = "probability" ):
9498 """
95- setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
96- maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
99+ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
100+ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
97101 threshold=0.5, probabilityCol="probability")
98102 Sets params for logistic regression.
99103 """
@@ -148,6 +152,20 @@ class LogisticRegressionModel(JavaModel):
148152 Model fitted by LogisticRegression.
149153 """
150154
155+ @property
156+ def weights (self ):
157+ """
158+ Model weights.
159+ """
160+ return self ._call_java ("weights" )
161+
162+ @property
163+ def intercept (self ):
164+ """
165+ Model intercept.
166+ """
167+ return self ._call_java ("intercept" )
168+
151169
152170class TreeClassifierParams (object ):
153171 """
@@ -202,7 +220,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
202220 maxMemoryInMB = 256 , cacheNodeIds = False , checkpointInterval = 10 , impurity = "gini" ):
203221 """
204222 __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
205- maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
223+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
206224 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini")
207225 """
208226 super (DecisionTreeClassifier , self ).__init__ ()
@@ -224,9 +242,8 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
224242 impurity = "gini" ):
225243 """
226244 setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
227- maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
228- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
229- impurity="gini")
245+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
246+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini")
230247 Sets params for the DecisionTreeClassifier.
231248 """
232249 kwargs = self .setParams ._input_kwargs
@@ -302,9 +319,9 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
302319 maxMemoryInMB = 256 , cacheNodeIds = False , checkpointInterval = 10 , impurity = "gini" ,
303320 numTrees = 20 , featureSubsetStrategy = "auto" , seed = 42 ):
304321 """
305- __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
306- maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
307- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
322+ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
323+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
324+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
308325 numTrees=20, featureSubsetStrategy="auto", seed=42)
309326 """
310327 super (RandomForestClassifier , self ).__init__ ()
@@ -337,9 +354,9 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
337354 maxMemoryInMB = 256 , cacheNodeIds = False , checkpointInterval = 10 , seed = 42 ,
338355 impurity = "gini" , numTrees = 20 , featureSubsetStrategy = "auto" ):
339356 """
340- setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
341- maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
342- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42,
357+ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
358+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
359+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42, \
343360 impurity="gini", numTrees=20, featureSubsetStrategy="auto")
344361 Sets params for linear classification.
345362 """
@@ -453,10 +470,10 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
453470 maxMemoryInMB = 256 , cacheNodeIds = False , checkpointInterval = 10 , lossType = "logistic" ,
454471 maxIter = 20 , stepSize = 0.1 ):
455472 """
456- __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
457- maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
458- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
459- maxIter=20, stepSize=0.1)
473+ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
474+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
475+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
476+ lossType="logistic", maxIter=20, stepSize=0.1)
460477 """
461478 super (GBTClassifier , self ).__init__ ()
462479 #: param for Loss function which GBT tries to minimize (case-insensitive).
@@ -484,9 +501,9 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
484501 maxMemoryInMB = 256 , cacheNodeIds = False , checkpointInterval = 10 ,
485502 lossType = "logistic" , maxIter = 20 , stepSize = 0.1 ):
486503 """
487- setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
488- maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
489- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
504+ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
505+ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
506+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
490507 lossType="logistic", maxIter=20, stepSize=0.1)
491508 Sets params for Gradient Boosted Tree Classification.
492509 """
0 commit comments