Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 122 additions & 100 deletions python/pyspark/mllib/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
from pyspark.mllib.linalg import _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.util import JavaLoader, JavaSaveable
from typing import Dict, Optional, Tuple, Union, overload, TYPE_CHECKING
from pyspark.rdd import RDD

if TYPE_CHECKING:
from pyspark.mllib._typing import VectorLike


__all__ = [
"DecisionTreeModel",
Expand All @@ -40,7 +46,15 @@ class TreeEnsembleModel(JavaModelWrapper, JavaSaveable):
.. versionadded:: 1.3.0
"""

def predict(self, x):
@overload
def predict(self, x: "VectorLike") -> float:
...

@overload
def predict(self, x: RDD["VectorLike"]) -> RDD[float]:
...

def predict(self, x: Union["VectorLike", RDD["VectorLike"]]) -> Union[float, RDD[float]]:
"""
Predict values for a single data point or an RDD of points using
the model trained.
Expand All @@ -60,37 +74,45 @@ def predict(self, x):
return self.call("predict", _convert_to_vector(x))

@since("1.3.0")
def numTrees(self):
def numTrees(self) -> int:
    """Return how many individual trees make up this ensemble."""
    tree_count = self.call("numTrees")
    return tree_count

@since("1.3.0")
def totalNumNodes(self):
def totalNumNodes(self) -> int:
    """Return the node count summed over every tree in the ensemble."""
    node_total = self.call("totalNumNodes")
    return node_total

def __repr__(self):
def __repr__(self) -> str:
    """One-line summary of the model, delegated to the JVM-side model."""
    java_model = self._java_model
    return java_model.toString()

@since("1.3.0")
def toDebugString(self):
def toDebugString(self) -> str:
    """Return the full, human-readable dump of the model."""
    java_model = self._java_model
    return java_model.toDebugString()


class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader):
class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader["DecisionTreeModel"]):
"""
A decision tree model for classification or regression.

.. versionadded:: 1.1.0
"""

def predict(self, x):
@overload
def predict(self, x: "VectorLike") -> float:
...

@overload
def predict(self, x: RDD["VectorLike"]) -> RDD[float]:
...

def predict(self, x: Union["VectorLike", RDD["VectorLike"]]) -> Union[float, RDD[float]]:
"""
Predict the label of one or more examples.

Expand All @@ -115,29 +137,29 @@ def predict(self, x):
return self.call("predict", _convert_to_vector(x))

@since("1.1.0")
def numNodes(self):
def numNodes(self) -> int:
    """Return the number of nodes in the tree, leaf nodes included."""
    java_model = self._java_model
    return java_model.numNodes()

@since("1.1.0")
def depth(self):
def depth(self) -> int:
    """Return the depth of the tree.

    Depth 0 means a single leaf node; depth 1 means one internal
    node plus two leaf nodes.
    """
    java_model = self._java_model
    return java_model.depth()

def __repr__(self):
def __repr__(self) -> str:
    """Short summary of the model, delegated to the JVM-side model."""
    java_model = self._java_model
    return java_model.toString()

@since("1.2.0")
def toDebugString(self):
def toDebugString(self) -> str:
    """Return the full, human-readable dump of the model."""
    java_model = self._java_model
    return java_model.toDebugString()

@classmethod
def _java_loader_class(cls):
def _java_loader_class(cls) -> str:
    """Fully qualified name of the JVM class used to load this model."""
    loader_class = "org.apache.spark.mllib.tree.model.DecisionTreeModel"
    return loader_class


Expand All @@ -152,16 +174,16 @@ class DecisionTree:
@classmethod
def _train(
cls,
data,
type,
numClasses,
features,
impurity="gini",
maxDepth=5,
maxBins=32,
minInstancesPerNode=1,
minInfoGain=0.0,
):
data: RDD[LabeledPoint],
type: str,
numClasses: int,
features: Dict[int, int],
impurity: str = "gini",
maxDepth: int = 5,
maxBins: int = 32,
minInstancesPerNode: int = 1,
minInfoGain: float = 0.0,
) -> DecisionTreeModel:
first = data.first()
assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint"
model = callMLlibFunc(
Expand All @@ -181,15 +203,15 @@ def _train(
@classmethod
def trainClassifier(
cls,
data,
numClasses,
categoricalFeaturesInfo,
impurity="gini",
maxDepth=5,
maxBins=32,
minInstancesPerNode=1,
minInfoGain=0.0,
):
data: RDD[LabeledPoint],
numClasses: int,
categoricalFeaturesInfo: Dict[int, int],
impurity: str = "gini",
maxDepth: int = 5,
maxBins: int = 32,
minInstancesPerNode: int = 1,
minInfoGain: float = 0.0,
) -> DecisionTreeModel:
"""
Train a decision tree model for classification.

Expand Down Expand Up @@ -276,14 +298,14 @@ def trainClassifier(
@since("1.1.0")
def trainRegressor(
cls,
data,
categoricalFeaturesInfo,
impurity="variance",
maxDepth=5,
maxBins=32,
minInstancesPerNode=1,
minInfoGain=0.0,
):
data: RDD[LabeledPoint],
categoricalFeaturesInfo: Dict[int, int],
impurity: str = "variance",
maxDepth: int = 5,
maxBins: int = 32,
minInstancesPerNode: int = 1,
minInfoGain: float = 0.0,
) -> DecisionTreeModel:
"""
Train a decision tree model for regression.

Expand Down Expand Up @@ -354,15 +376,15 @@ def trainRegressor(


@inherit_doc
class RandomForestModel(TreeEnsembleModel, JavaLoader):
class RandomForestModel(TreeEnsembleModel, JavaLoader["RandomForestModel"]):
"""
Represents a random forest model.

.. versionadded:: 1.2.0
"""

@classmethod
def _java_loader_class(cls):
def _java_loader_class(cls) -> str:
    """Fully qualified name of the JVM class used to load this model."""
    loader_class = "org.apache.spark.mllib.tree.model.RandomForestModel"
    return loader_class


Expand All @@ -374,22 +396,22 @@ class RandomForest:
.. versionadded:: 1.2.0
"""

supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird")
# Accepted values for the featureSubsetStrategy argument of the train
# methods; _train rejects any strategy not listed here before calling
# into MLlib.
supportedFeatureSubsetStrategies: Tuple[str, ...] = ("auto", "all", "sqrt", "log2", "onethird")

@classmethod
def _train(
cls,
data,
algo,
numClasses,
categoricalFeaturesInfo,
numTrees,
featureSubsetStrategy,
impurity,
maxDepth,
maxBins,
seed,
):
data: RDD[LabeledPoint],
algo: str,
numClasses: int,
categoricalFeaturesInfo: Dict[int, int],
numTrees: int,
featureSubsetStrategy: str,
impurity: str,
maxDepth: int,
maxBins: int,
seed: Optional[int],
) -> RandomForestModel:
first = data.first()
assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint"
if featureSubsetStrategy not in cls.supportedFeatureSubsetStrategies:
Expand All @@ -414,16 +436,16 @@ def _train(
@classmethod
def trainClassifier(
cls,
data,
numClasses,
categoricalFeaturesInfo,
numTrees,
featureSubsetStrategy="auto",
impurity="gini",
maxDepth=4,
maxBins=32,
seed=None,
):
data: RDD[LabeledPoint],
numClasses: int,
categoricalFeaturesInfo: Dict[int, int],
numTrees: int,
featureSubsetStrategy: str = "auto",
impurity: str = "gini",
maxDepth: int = 4,
maxBins: int = 32,
seed: Optional[int] = None,
) -> RandomForestModel:
"""
Train a random forest model for binary or multiclass
classification.
Expand Down Expand Up @@ -530,15 +552,15 @@ def trainClassifier(
@classmethod
def trainRegressor(
cls,
data,
categoricalFeaturesInfo,
numTrees,
featureSubsetStrategy="auto",
impurity="variance",
maxDepth=4,
maxBins=32,
seed=None,
):
data: RDD[LabeledPoint],
categoricalFeaturesInfo: Dict[int, int],
numTrees: int,
featureSubsetStrategy: str = "auto",
impurity: str = "variance",
maxDepth: int = 4,
maxBins: int = 32,
seed: Optional[int] = None,
) -> RandomForestModel:
"""
Train a random forest model for regression.

Expand Down Expand Up @@ -625,15 +647,15 @@ def trainRegressor(


@inherit_doc
class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader):
class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader["GradientBoostedTreesModel"]):
"""
Represents a gradient-boosted tree model.

.. versionadded:: 1.3.0
"""

@classmethod
def _java_loader_class(cls):
def _java_loader_class(cls) -> str:
    """Fully qualified name of the JVM class used to load this model."""
    loader_class = "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel"
    return loader_class


Expand All @@ -648,15 +670,15 @@ class GradientBoostedTrees:
@classmethod
def _train(
cls,
data,
algo,
categoricalFeaturesInfo,
loss,
numIterations,
learningRate,
maxDepth,
maxBins,
):
data: RDD[LabeledPoint],
algo: str,
categoricalFeaturesInfo: Dict[int, int],
loss: str,
numIterations: int,
learningRate: float,
maxDepth: int,
maxBins: int,
) -> GradientBoostedTreesModel:
first = data.first()
assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint"
model = callMLlibFunc(
Expand All @@ -675,14 +697,14 @@ def _train(
@classmethod
def trainClassifier(
cls,
data,
categoricalFeaturesInfo,
loss="logLoss",
numIterations=100,
learningRate=0.1,
maxDepth=3,
maxBins=32,
):
data: RDD[LabeledPoint],
categoricalFeaturesInfo: Dict[int, int],
loss: str = "logLoss",
numIterations: int = 100,
learningRate: float = 0.1,
maxDepth: int = 3,
maxBins: int = 32,
) -> GradientBoostedTreesModel:
"""
Train a gradient-boosted trees model for classification.

Expand Down Expand Up @@ -765,14 +787,14 @@ def trainClassifier(
@classmethod
def trainRegressor(
cls,
data,
categoricalFeaturesInfo,
loss="leastSquaresError",
numIterations=100,
learningRate=0.1,
maxDepth=3,
maxBins=32,
):
data: RDD[LabeledPoint],
categoricalFeaturesInfo: Dict[int, int],
loss: str = "leastSquaresError",
numIterations: int = 100,
learningRate: float = 0.1,
maxDepth: int = 3,
maxBins: int = 32,
) -> GradientBoostedTreesModel:
"""
Train a gradient-boosted trees model for regression.

Expand Down Expand Up @@ -851,7 +873,7 @@ def trainRegressor(
)


def _test():
def _test() -> None:
import doctest

globs = globals().copy()
Expand Down
Loading