-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-3158][MLLIB]Avoid 1 extra aggregation for DecisionTree training #2708
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
6cc0333
e41d715
822c912
7ad7a71
c41b1b6
eefeef1
8e269ea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -532,6 +532,17 @@ object DecisionTree extends Serializable with Logging { | |
| Some(mutableNodeToFeatures.toMap) | ||
| } | ||
|
|
||
| // array of nodes to train indexed by node index in group | ||
| val nodes = { | ||
| val nodes = Array.fill[Node](numNodes)(null) | ||
| nodesForGroup.foreach { case (treeIndex, nodesForTree) => | ||
| nodesForTree.foreach { node => | ||
| nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node | ||
| } | ||
| } | ||
| nodes | ||
| } | ||
|
|
||
| // Calculate best splits for all nodes in the group | ||
| timer.start("chooseSplits") | ||
|
|
||
|
|
@@ -568,7 +579,7 @@ object DecisionTree extends Serializable with Logging { | |
|
|
||
| // find best split for each node | ||
| val (split: Split, stats: InformationGainStats, predict: Predict) = | ||
| binsToBestSplit(aggStats, splits, featuresForNode) | ||
| binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) | ||
| (nodeIndex, (split, stats, predict)) | ||
| }.collectAsMap() | ||
|
|
||
|
|
@@ -587,17 +598,26 @@ object DecisionTree extends Serializable with Logging { | |
| // Extract info for this node. Create children if not leaf. | ||
| val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth) | ||
| assert(node.id == nodeIndex) | ||
| node.predict = predict.predict | ||
| node.predict = predict | ||
| node.isLeaf = isLeaf | ||
| node.stats = Some(stats) | ||
| node.impurity = stats.impurity | ||
| logDebug("Node = " + node) | ||
|
|
||
| if (!isLeaf) { | ||
| node.split = Some(split) | ||
| node.leftNode = Some(Node.emptyNode(Node.leftChildIndex(nodeIndex))) | ||
| node.rightNode = Some(Node.emptyNode(Node.rightChildIndex(nodeIndex))) | ||
| nodeQueue.enqueue((treeIndex, node.leftNode.get)) | ||
| nodeQueue.enqueue((treeIndex, node.rightNode.get)) | ||
| val childIsLeaf = (Node.indexToLevel(nodeIndex) + 1) == metadata.maxDepth | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We can also check stats.leftImpurity and stats.rightImpurity: if stats.leftImpurity == 0, then we know the left child will be a leaf, and the same holds for the right child.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Once this is done, it would be good to add one more test to make sure it works. A slight modification of the test you already added should suffice. |
||
| node.leftNode = Some(Node(Node.leftChildIndex(nodeIndex), | ||
| stats.leftPredict, stats.leftImpurity, childIsLeaf)) | ||
| node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex), | ||
| stats.rightPredict, stats.rightImpurity, childIsLeaf)) | ||
|
|
||
| // enqueue left child and right child if they are not leaves | ||
| if (!childIsLeaf) { | ||
| nodeQueue.enqueue((treeIndex, node.leftNode.get)) | ||
| nodeQueue.enqueue((treeIndex, node.rightNode.get)) | ||
| } | ||
|
|
||
| logDebug("leftChildIndex = " + node.leftNode.get.id + | ||
| ", impurity = " + stats.leftImpurity) | ||
| logDebug("rightChildIndex = " + node.rightNode.get.id + | ||
|
|
@@ -617,7 +637,8 @@ object DecisionTree extends Serializable with Logging { | |
| private def calculateGainForSplit( | ||
| leftImpurityCalculator: ImpurityCalculator, | ||
| rightImpurityCalculator: ImpurityCalculator, | ||
| metadata: DecisionTreeMetadata): InformationGainStats = { | ||
| metadata: DecisionTreeMetadata, | ||
| impurity: Double): InformationGainStats = { | ||
| val leftCount = leftImpurityCalculator.count | ||
| val rightCount = rightImpurityCalculator.count | ||
|
|
||
|
|
@@ -630,11 +651,6 @@ object DecisionTree extends Serializable with Logging { | |
|
|
||
| val totalCount = leftCount + rightCount | ||
|
|
||
| val parentNodeAgg = leftImpurityCalculator.copy | ||
| parentNodeAgg.add(rightImpurityCalculator) | ||
|
|
||
| val impurity = parentNodeAgg.calculate() | ||
|
|
||
| val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 | ||
| val rightImpurity = rightImpurityCalculator.calculate() | ||
|
|
||
|
|
@@ -649,7 +665,18 @@ object DecisionTree extends Serializable with Logging { | |
| return InformationGainStats.invalidInformationGainStats | ||
| } | ||
|
|
||
| new InformationGainStats(gain, impurity, leftImpurity, rightImpurity) | ||
| // calculate left and right predict | ||
| val leftPredict = calculatePredict(leftImpurityCalculator) | ||
| val rightPredict = calculatePredict(rightImpurityCalculator) | ||
|
|
||
| new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, | ||
| leftPredict, rightPredict) | ||
| } | ||
|
|
||
| private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = { | ||
| val predict = impurityCalculator.predict | ||
| val prob = impurityCalculator.prob(predict) | ||
| new Predict(predict, prob) | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -659,15 +686,16 @@ object DecisionTree extends Serializable with Logging { | |
| * @param rightImpurityCalculator right node aggregates for a split | ||
| * @return predict value for current node | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| */ | ||
| private def calculatePredict( | ||
| private def calculatePredictImpurity( | ||
| leftImpurityCalculator: ImpurityCalculator, | ||
| rightImpurityCalculator: ImpurityCalculator): Predict = { | ||
| rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { | ||
| val parentNodeAgg = leftImpurityCalculator.copy | ||
| parentNodeAgg.add(rightImpurityCalculator) | ||
| val predict = parentNodeAgg.predict | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You could make this a little shorter by calling calculatePredict() on parentNodeAgg. |
||
| val prob = parentNodeAgg.prob(predict) | ||
| val impurity = parentNodeAgg.calculate() | ||
|
|
||
| new Predict(predict, prob) | ||
| (new Predict(predict, prob), impurity) | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -678,10 +706,16 @@ object DecisionTree extends Serializable with Logging { | |
| private def binsToBestSplit( | ||
| binAggregates: DTStatsAggregator, | ||
| splits: Array[Array[Split]], | ||
| featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = { | ||
| featuresForNode: Option[Array[Int]], | ||
| node: Node): (Split, InformationGainStats, Predict) = { | ||
|
|
||
| // calculate predict only once | ||
| var predict: Option[Predict] = None | ||
| // calculate predict and impurity if current node is top node | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "if current node are top node" --> "if current node is top node" |
||
| val level = Node.indexToLevel(node.id) | ||
| var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) { | ||
| None | ||
| } else { | ||
| Some((node.predict, node.impurity)) | ||
| } | ||
|
|
||
| // For each (feature, split), calculate the gain, and select the best (feature, split). | ||
| val (bestSplit, bestSplitStats) = | ||
|
|
@@ -708,9 +742,10 @@ object DecisionTree extends Serializable with Logging { | |
| val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) | ||
| val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) | ||
| rightChildStats.subtract(leftChildStats) | ||
| predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) | ||
| predictWithImpurity = Some(predictWithImpurity.getOrElse( | ||
| calculatePredictImpurity(leftChildStats, rightChildStats))) | ||
| val gainStats = calculateGainForSplit(leftChildStats, | ||
| rightChildStats, binAggregates.metadata) | ||
| rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) | ||
| (splitIdx, gainStats) | ||
| }.maxBy(_._2.gain) | ||
| (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) | ||
|
|
@@ -722,9 +757,10 @@ object DecisionTree extends Serializable with Logging { | |
| Range(0, numSplits).map { splitIndex => | ||
| val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) | ||
| val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) | ||
| predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) | ||
| predictWithImpurity = Some(predictWithImpurity.getOrElse( | ||
| calculatePredictImpurity(leftChildStats, rightChildStats))) | ||
| val gainStats = calculateGainForSplit(leftChildStats, | ||
| rightChildStats, binAggregates.metadata) | ||
| rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) | ||
| (splitIndex, gainStats) | ||
| }.maxBy(_._2.gain) | ||
| (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) | ||
|
|
@@ -794,9 +830,10 @@ object DecisionTree extends Serializable with Logging { | |
| val rightChildStats = | ||
| binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) | ||
| rightChildStats.subtract(leftChildStats) | ||
| predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) | ||
| predictWithImpurity = Some(predictWithImpurity.getOrElse( | ||
| calculatePredictImpurity(leftChildStats, rightChildStats))) | ||
| val gainStats = calculateGainForSplit(leftChildStats, | ||
| rightChildStats, binAggregates.metadata) | ||
| rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) | ||
| (splitIndex, gainStats) | ||
| }.maxBy(_._2.gain) | ||
| val categoriesForSplit = | ||
|
|
@@ -807,9 +844,7 @@ object DecisionTree extends Serializable with Logging { | |
| } | ||
| }.maxBy(_._2.gain) | ||
|
|
||
| assert(predict.isDefined, "must calculate predict for each node") | ||
|
|
||
| (bestSplit, bestSplitStats, predict.get) | ||
| (bestSplit, bestSplitStats, predictWithImpurity.get._1) | ||
| } | ||
|
|
||
| /** | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,6 +32,7 @@ import org.apache.spark.mllib.linalg.Vector | |
| * | ||
| * @param id integer node id, from 1 | ||
| * @param predict predicted value at the node | ||
| * @param impurity current node impurity | ||
| * @param isLeaf whether the node is a leaf | ||
| * @param split split to calculate left and right nodes | ||
| * @param leftNode left child | ||
|
|
@@ -41,15 +42,16 @@ import org.apache.spark.mllib.linalg.Vector | |
| @DeveloperApi | ||
| class Node ( | ||
| val id: Int, | ||
| var predict: Double, | ||
| var predict: Predict, | ||
| var impurity: Double, | ||
| var isLeaf: Boolean, | ||
| var split: Option[Split], | ||
| var leftNode: Option[Node], | ||
| var rightNode: Option[Node], | ||
| var stats: Option[InformationGainStats]) extends Serializable with Logging { | ||
|
|
||
| override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + | ||
| "split = " + split + ", stats = " + stats | ||
| "impurity = " + impurity + "split = " + split + ", stats = " + stats | ||
|
|
||
| /** | ||
| * build the left node and right nodes if not leaf | ||
|
|
@@ -62,6 +64,7 @@ class Node ( | |
| logDebug("id = " + id + ", split = " + split) | ||
| logDebug("stats = " + stats) | ||
| logDebug("predict = " + predict) | ||
| logDebug("impurity = " + impurity) | ||
| if (!isLeaf) { | ||
| leftNode = Some(nodes(Node.leftChildIndex(id))) | ||
| rightNode = Some(nodes(Node.rightChildIndex(id))) | ||
|
|
@@ -77,7 +80,7 @@ class Node ( | |
| */ | ||
| def predict(features: Vector) : Double = { | ||
| if (isLeaf) { | ||
| predict | ||
| predict.predict | ||
| } else{ | ||
| if (split.get.featureType == Continuous) { | ||
| if (features(split.get.feature) <= split.get.threshold) { | ||
|
|
@@ -109,7 +112,7 @@ class Node ( | |
| } else { | ||
| Some(rightNode.get.deepCopy()) | ||
| } | ||
| new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats) | ||
| new Node(id, predict, impurity, isLeaf, split, leftNodeCopy, rightNodeCopy, stats) | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -154,7 +157,7 @@ class Node ( | |
| } | ||
| val prefix: String = " " * indentFactor | ||
| if (isLeaf) { | ||
| prefix + s"Predict: $predict\n" | ||
| prefix + s"Predict: ${predict.predict}\n" | ||
| } else { | ||
| prefix + s"If ${splitToString(split.get, left=true)}\n" + | ||
| leftNode.get.subtreeToString(indentFactor + 1) + | ||
|
|
@@ -170,7 +173,27 @@ private[tree] object Node { | |
| /** | ||
| * Return a node with the given node id (but nothing else set). | ||
| */ | ||
| def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, 0, false, None, None, None, None) | ||
| def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, new Predict(Double.MinValue), -1.0, | ||
| false, None, None, None, None) | ||
|
|
||
| /** | ||
| * Construct a node with nodeIndex, predict, impurity and isLeaf parameters. | ||
| * This is used in `DecisionTree.findBestSplits` to construct child nodes | ||
| * after finding the best splits for parent nodes. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "after find best splits for each node" --> "after finding the best splits for parent nodes" |
||
| * Other fields are set at next level. | ||
| * @param nodeIndex integer node id, from 1 | ||
| * @param predict predicted value at the node | ||
| * @param impurity current node impurity | ||
| * @param isLeaf whether the node is a leaf | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "whether the leaf is a node" --> "whether the node is a leaf" |
||
| * @return new node instance | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "newed" --> "new" |
||
| */ | ||
| def apply( | ||
| nodeIndex: Int, | ||
| predict: Predict, | ||
| impurity: Double, | ||
| isLeaf: Boolean): Node = { | ||
| new Node(nodeIndex, predict, impurity, isLeaf, None, None, None, None) | ||
| } | ||
|
|
||
| /** | ||
| * Return the index of the left child of this node. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be shorter: there is no need for two nested `nodes` declarations (the inner `val nodes` wrapped inside the outer `val nodes = { ... }` block).