diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 60bdeedd6a14..f43780b35c0d 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -862,6 +862,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> model_path = temp_path + "/rfc_model" >>> model.save(model_path) >>> model2 = RandomForestClassificationModel.load(model_path) + >>> model.getMaxDepth() == model2.getMaxDepth() + True >>> model.featureImportances == model2.featureImportances True @@ -997,6 +999,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> model_path = temp_path + "gbtc_model" >>> model.save(model_path) >>> model2 = GBTClassificationModel.load(model_path) + >>> model.getMaxDepth() == model2.getMaxDepth() + True >>> model.featureImportances == model2.featureImportances True >>> model.treeWeights == model2.treeWeights diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 2d17f95b0c44..1434924e0d50 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -785,6 +785,12 @@ def getNumTrees(self): """Number of trees in ensemble.""" return self._call_java("getNumTrees") + @property + @since("2.3.0") + def getMaxDepth(self): + """Maximum depth of the tree (>= 0).""" + return self._call_java("getMaxDepth") + @property @since("1.5.0") def treeWeights(self): @@ -877,6 +883,8 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> model_path = temp_path + "/rfr_model" >>> model.save(model_path) >>> model2 = RandomForestRegressionModel.load(model_path) + >>> model.getMaxDepth() == model2.getMaxDepth() + True >>> model.featureImportances == model2.featureImportances True @@ -996,6 +1004,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, >>> model_path = temp_path + "gbtr_model" >>> model.save(model_path) >>> model2 = GBTRegressionModel.load(model_path) + >>> model.getMaxDepth() == model2.getMaxDepth() + True >>> model.featureImportances == model2.featureImportances True >>> model.treeWeights == model2.treeWeights