Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifi
@Since("1.4.0")
class DecisionTreeClassificationModel private[ml] (
@Since("1.4.0")override val uid: String,
@Since("1.4.0")override val rootNode: ClassificationNode,
@Since("1.4.0")override val rootNode: Node,
@Since("1.6.0")override val numFeatures: Int,
@Since("1.5.0")override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
Expand All @@ -181,7 +181,7 @@ class DecisionTreeClassificationModel private[ml] (
* Construct a decision tree classification model.
* @param rootNode Root node of tree, with other nodes attached.
*/
private[ml] def this(rootNode: ClassificationNode, numFeatures: Int, numClasses: Int) =
private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses)

override def predict(features: Vector): Double = {
Expand Down Expand Up @@ -279,9 +279,8 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
val root = loadTreeNodes(path, metadata, sparkSession, isClassification = true)
val model = new DecisionTreeClassificationModel(metadata.uid,
root.asInstanceOf[ClassificationNode], numFeatures, numClasses)
val root = loadTreeNodes(path, metadata, sparkSession)
val model = new DecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses)
metadata.getAndSetParams(model)
model
}
Expand All @@ -296,10 +295,9 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
require(oldModel.algo == OldAlgo.Classification,
s"Cannot convert non-classification DecisionTreeModel (old API) to" +
s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = true)
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
// Can't infer number of features from old model, so default to -1
new DecisionTreeClassificationModel(uid,
rootNode.asInstanceOf[ClassificationNode], numFeatures, -1)
new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -412,14 +412,14 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
override def load(path: String): GBTClassificationModel = {
implicit val format = DefaultFormats
val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false)
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
val numFeatures = (metadata.metadata \ numFeaturesKey).extract[Int]
val numTrees = (metadata.metadata \ numTreesKey).extract[Int]

val trees: Array[DecisionTreeRegressionModel] = treesData.map {
case (treeMetadata, root) =>
val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
root.asInstanceOf[RegressionNode], numFeatures)
val tree =
new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
treeMetadata.getAndSetParams(tree)
tree
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,15 +313,15 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
override def load(path: String): RandomForestClassificationModel = {
implicit val format = DefaultFormats
val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) =
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, true)
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
val numTrees = (metadata.metadata \ "numTrees").extract[Int]

val trees: Array[DecisionTreeClassificationModel] = treesData.map {
case (treeMetadata, root) =>
val tree = new DecisionTreeClassificationModel(treeMetadata.uid,
root.asInstanceOf[ClassificationNode], numFeatures, numClasses)
val tree =
new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses)
treeMetadata.getAndSetParams(tree)
tree
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor
@Since("1.4.0")
class DecisionTreeRegressionModel private[ml] (
override val uid: String,
override val rootNode: RegressionNode,
override val rootNode: Node,
override val numFeatures: Int)
extends PredictionModel[Vector, DecisionTreeRegressionModel]
with DecisionTreeModel with DecisionTreeRegressorParams with MLWritable with Serializable {
Expand All @@ -175,7 +175,7 @@ class DecisionTreeRegressionModel private[ml] (
* Construct a decision tree regression model.
* @param rootNode Root node of tree, with other nodes attached.
*/
private[ml] def this(rootNode: RegressionNode, numFeatures: Int) =
private[ml] def this(rootNode: Node, numFeatures: Int) =
this(Identifiable.randomUID("dtr"), rootNode, numFeatures)

override def predict(features: Vector): Double = {
Expand Down Expand Up @@ -279,9 +279,8 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode
implicit val format = DefaultFormats
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val root = loadTreeNodes(path, metadata, sparkSession, isClassification = false)
val model = new DecisionTreeRegressionModel(metadata.uid,
root.asInstanceOf[RegressionNode], numFeatures)
val root = loadTreeNodes(path, metadata, sparkSession)
val model = new DecisionTreeRegressionModel(metadata.uid, root, numFeatures)
metadata.getAndSetParams(model)
model
}
Expand All @@ -296,8 +295,8 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode
require(oldModel.algo == OldAlgo.Regression,
s"Cannot convert non-regression DecisionTreeModel (old API) to" +
s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}")
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures, isClassification = false)
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr")
new DecisionTreeRegressionModel(uid, rootNode.asInstanceOf[RegressionNode], numFeatures)
new DecisionTreeRegressionModel(uid, rootNode, numFeatures)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -338,15 +338,15 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] {
override def load(path: String): GBTRegressionModel = {
implicit val format = DefaultFormats
val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false)
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)

val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numTrees = (metadata.metadata \ "numTrees").extract[Int]

val trees: Array[DecisionTreeRegressionModel] = treesData.map {
case (treeMetadata, root) =>
val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
root.asInstanceOf[RegressionNode], numFeatures)
val tree =
new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
treeMetadata.getAndSetParams(tree)
tree
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,13 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode
override def load(path: String): RandomForestRegressionModel = {
implicit val format = DefaultFormats
val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) =
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName, false)
EnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName)
val numFeatures = (metadata.metadata \ "numFeatures").extract[Int]
val numTrees = (metadata.metadata \ "numTrees").extract[Int]

val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) =>
val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
root.asInstanceOf[RegressionNode], numFeatures)
val tree =
new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures)
treeMetadata.getAndSetParams(tree)
tree
}
Expand Down
Loading