From 96e797f79000964ea226efced7318f24a1722535 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Fri, 26 May 2017 15:48:02 +0800 Subject: [PATCH 1/2] ENH: add maxDepth for resemble tree model --- python/pyspark/ml/regression.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 2d17f95b0c44..f9816a9fc22c 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -785,6 +785,12 @@ def getNumTrees(self): """Number of trees in ensemble.""" return self._call_java("getNumTrees") + @property + @since("2.3.0") + def getMaxDepth(self): + """Maximum depth of the tree (>= 0).""" + return self._call_java("getMaxDepth") + @property @since("1.5.0") def treeWeights(self): From 50815928f070f89ae71e235d51e833cef596a361 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Fri, 26 May 2017 16:45:11 +0800 Subject: [PATCH 2/2] TST: add doctest --- python/pyspark/ml/classification.py | 4 ++++ python/pyspark/ml/regression.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 60bdeedd6a14..f43780b35c0d 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -862,6 +862,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> model_path = temp_path + "/rfc_model" >>> model.save(model_path) >>> model2 = RandomForestClassificationModel.load(model_path) + >>> model.getMaxDepth() == model2.getMaxDepth() + True >>> model.featureImportances == model2.featureImportances True @@ -997,6 +999,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> model_path = temp_path + "gbtc_model" >>> model.save(model_path) >>> model2 = GBTClassificationModel.load(model_path) + >>> model.getMaxDepth() == model2.getMaxDepth() + True >>> model.featureImportances == model2.featureImportances True >>> model.treeWeights == model2.treeWeights diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index f9816a9fc22c..1434924e0d50 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -883,6 +883,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> model_path = temp_path + "/rfr_model" >>> model.save(model_path) >>> model2 = RandomForestRegressionModel.load(model_path) + >>> model.getMaxDepth() == model2.getMaxDepth() + True >>> model.featureImportances == model2.featureImportances True @@ -1002,6 +1004,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, >>> model_path = temp_path + "gbtr_model" >>> model.save(model_path) >>> model2 = GBTRegressionModel.load(model_path) + >>> model.getMaxDepth() == model2.getMaxDepth() + True >>> model.featureImportances == model2.featureImportances True >>> model.treeWeights == model2.treeWeights