Skip to content

Commit 6b8cb1f

Browse files
committed
[SPARK-17197][ML][PYSPARK] PySpark LiR/LoR supports tree aggregation level configurable.
## What changes were proposed in this pull request? [SPARK-17090](https://issues.apache.org/jira/browse/SPARK-17090) makes tree aggregation level in LiR/LoR configurable, this PR makes PySpark support this function. ## How was this patch tested? Since ```aggregationDepth``` is an expert param, I'm not prefer to test it in doctest which is also used for example. Here is the offline test result: ![image](https://cloud.githubusercontent.com/assets/1962026/17879457/f83d7760-68a6-11e6-9936-d0a884d5d6ec.png) Author: Yanbo Liang <[email protected]> Closes apache#14766 from yanboliang/spark-17197.
1 parent e0b20f9 commit 6b8cb1f

File tree

4 files changed

+42
-11
lines changed

4 files changed

+42
-11
lines changed

python/pyspark/ml/classification.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def numClasses(self):
6464
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
6565
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
6666
HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds,
67-
HasWeightCol, JavaMLWritable, JavaMLReadable):
67+
HasWeightCol, HasAggregationDepth, JavaMLWritable, JavaMLReadable):
6868
"""
6969
Logistic regression.
7070
Currently, this class only supports binary classification.
@@ -121,12 +121,14 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
121121
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
122122
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
123123
threshold=0.5, thresholds=None, probabilityCol="probability",
124-
rawPredictionCol="rawPrediction", standardization=True, weightCol=None):
124+
rawPredictionCol="rawPrediction", standardization=True, weightCol=None,
125+
aggregationDepth=2):
125126
"""
126127
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
127128
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
128129
threshold=0.5, thresholds=None, probabilityCol="probability", \
129-
rawPredictionCol="rawPrediction", standardization=True, weightCol=None)
130+
rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \
131+
aggregationDepth=2)
130132
If the threshold and thresholds Params are both set, they must be equivalent.
131133
"""
132134
super(LogisticRegression, self).__init__()
@@ -142,12 +144,14 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
142144
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
143145
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
144146
threshold=0.5, thresholds=None, probabilityCol="probability",
145-
rawPredictionCol="rawPrediction", standardization=True, weightCol=None):
147+
rawPredictionCol="rawPrediction", standardization=True, weightCol=None,
148+
aggregationDepth=2):
146149
"""
147150
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
148151
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
149152
threshold=0.5, thresholds=None, probabilityCol="probability", \
150-
rawPredictionCol="rawPrediction", standardization=True, weightCol=None)
153+
rawPredictionCol="rawPrediction", standardization=True, weightCol=None, \
154+
aggregationDepth=2)
151155
Sets params for logistic regression.
152156
If the threshold and thresholds Params are both set, they must be equivalent.
153157
"""

python/pyspark/ml/param/_shared_params_code_gen.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,9 @@ def get$Name(self):
147147
("solver", "the solver algorithm for optimization. If this is not set or empty, " +
148148
"default value is 'auto'.", "'auto'", "TypeConverters.toString"),
149149
("varianceCol", "column name for the biased sample variance of prediction.",
150-
None, "TypeConverters.toString")]
150+
None, "TypeConverters.toString"),
151+
("aggregationDepth", "suggested depth for treeAggregate (>= 2).", "2",
152+
"TypeConverters.toInt")]
151153

152154
code = []
153155
for name, doc, defaultValueStr, typeConverter in shared:

python/pyspark/ml/param/shared.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,30 @@ def getVarianceCol(self):
560560
return self.getOrDefault(self.varianceCol)
561561

562562

563+
class HasAggregationDepth(Params):
564+
"""
565+
Mixin for param aggregationDepth: suggested depth for treeAggregate (>= 2).
566+
"""
567+
568+
aggregationDepth = Param(Params._dummy(), "aggregationDepth", "suggested depth for treeAggregate (>= 2).", typeConverter=TypeConverters.toInt)
569+
570+
def __init__(self):
571+
super(HasAggregationDepth, self).__init__()
572+
self._setDefault(aggregationDepth=2)
573+
574+
def setAggregationDepth(self, value):
575+
"""
576+
Sets the value of :py:attr:`aggregationDepth`.
577+
"""
578+
return self._set(aggregationDepth=value)
579+
580+
def getAggregationDepth(self):
581+
"""
582+
Gets the value of aggregationDepth or its default value.
583+
"""
584+
return self.getOrDefault(self.aggregationDepth)
585+
586+
563587
class DecisionTreeParams(Params):
564588
"""
565589
Mixin for Decision Tree parameters.

python/pyspark/ml/regression.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@
3939
@inherit_doc
4040
class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
4141
HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
42-
HasStandardization, HasSolver, HasWeightCol, JavaMLWritable, JavaMLReadable):
42+
HasStandardization, HasSolver, HasWeightCol, HasAggregationDepth,
43+
JavaMLWritable, JavaMLReadable):
4344
"""
4445
Linear regression.
4546
@@ -97,11 +98,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
9798
@keyword_only
9899
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
99100
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
100-
standardization=True, solver="auto", weightCol=None):
101+
standardization=True, solver="auto", weightCol=None, aggregationDepth=2):
101102
"""
102103
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
103104
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
104-
standardization=True, solver="auto", weightCol=None)
105+
standardization=True, solver="auto", weightCol=None, aggregationDepth=2)
105106
"""
106107
super(LinearRegression, self).__init__()
107108
self._java_obj = self._new_java_obj(
@@ -114,11 +115,11 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
114115
@since("1.4.0")
115116
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
116117
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
117-
standardization=True, solver="auto", weightCol=None):
118+
standardization=True, solver="auto", weightCol=None, aggregationDepth=2):
118119
"""
119120
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
120121
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
121-
standardization=True, solver="auto", weightCol=None)
122+
standardization=True, solver="auto", weightCol=None, aggregationDepth=2)
122123
Sets params for linear regression.
123124
"""
124125
kwargs = self.setParams._input_kwargs

0 commit comments

Comments
 (0)