-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-3158][MLLIB] Avoid 1 extra aggregation for DecisionTree training #2708
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
6cc0333
e41d715
822c912
7ad7a71
c41b1b6
eefeef1
8e269ea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -532,6 +532,14 @@ object DecisionTree extends Serializable with Logging { | |
| Some(mutableNodeToFeatures.toMap) | ||
| } | ||
|
|
||
| // array of nodes to train indexed by node index in group | ||
| val nodes = Array.fill[Node](numNodes)(null) | ||
| nodesForGroup.foreach { case (treeIndex, nodesForTree) => | ||
| nodesForTree.foreach { node => | ||
| nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node | ||
| } | ||
| } | ||
|
|
||
| // Calculate best splits for all nodes in the group | ||
| timer.start("chooseSplits") | ||
|
|
||
|
|
@@ -568,7 +576,7 @@ object DecisionTree extends Serializable with Logging { | |
|
|
||
| // find best split for each node | ||
| val (split: Split, stats: InformationGainStats, predict: Predict) = | ||
| binsToBestSplit(aggStats, splits, featuresForNode) | ||
| binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) | ||
| (nodeIndex, (split, stats, predict)) | ||
| }.collectAsMap() | ||
|
|
||
|
|
@@ -587,17 +595,30 @@ object DecisionTree extends Serializable with Logging { | |
| // Extract info for this node. Create children if not leaf. | ||
| val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth) | ||
| assert(node.id == nodeIndex) | ||
| node.predict = predict.predict | ||
| node.predict = predict | ||
| node.isLeaf = isLeaf | ||
| node.stats = Some(stats) | ||
| node.impurity = stats.impurity | ||
| logDebug("Node = " + node) | ||
|
|
||
| if (!isLeaf) { | ||
| node.split = Some(split) | ||
| node.leftNode = Some(Node.emptyNode(Node.leftChildIndex(nodeIndex))) | ||
| node.rightNode = Some(Node.emptyNode(Node.rightChildIndex(nodeIndex))) | ||
| nodeQueue.enqueue((treeIndex, node.leftNode.get)) | ||
| nodeQueue.enqueue((treeIndex, node.rightNode.get)) | ||
| val childIsLeaf = (Node.indexToLevel(nodeIndex) + 1) == metadata.maxDepth | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We can also check stats.leftImpurity and stats.rightImpurity. If stats.leftImpurity = 0, then we know the left child will be a leaf. Same for the right child.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Once this is done, it might be good to add 1 more test to make sure it works. A slight modification of the test you already added should work. |
||
| val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0) | ||
| val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0) | ||
| node.leftNode = Some(Node(Node.leftChildIndex(nodeIndex), | ||
| stats.leftPredict, stats.leftImpurity, leftChildIsLeaf)) | ||
| node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex), | ||
| stats.rightPredict, stats.rightImpurity, rightChildIsLeaf)) | ||
|
|
||
| // enqueue left child and right child if they are not leaves | ||
| if (!leftChildIsLeaf) { | ||
| nodeQueue.enqueue((treeIndex, node.leftNode.get)) | ||
| } | ||
| if (!rightChildIsLeaf) { | ||
| nodeQueue.enqueue((treeIndex, node.rightNode.get)) | ||
| } | ||
|
|
||
| logDebug("leftChildIndex = " + node.leftNode.get.id + | ||
| ", impurity = " + stats.leftImpurity) | ||
| logDebug("rightChildIndex = " + node.rightNode.get.id + | ||
|
|
@@ -617,7 +638,8 @@ object DecisionTree extends Serializable with Logging { | |
| private def calculateGainForSplit( | ||
| leftImpurityCalculator: ImpurityCalculator, | ||
| rightImpurityCalculator: ImpurityCalculator, | ||
| metadata: DecisionTreeMetadata): InformationGainStats = { | ||
| metadata: DecisionTreeMetadata, | ||
| impurity: Double): InformationGainStats = { | ||
| val leftCount = leftImpurityCalculator.count | ||
| val rightCount = rightImpurityCalculator.count | ||
|
|
||
|
|
@@ -630,11 +652,6 @@ object DecisionTree extends Serializable with Logging { | |
|
|
||
| val totalCount = leftCount + rightCount | ||
|
|
||
| val parentNodeAgg = leftImpurityCalculator.copy | ||
| parentNodeAgg.add(rightImpurityCalculator) | ||
|
|
||
| val impurity = parentNodeAgg.calculate() | ||
|
|
||
| val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 | ||
| val rightImpurity = rightImpurityCalculator.calculate() | ||
|
|
||
|
|
@@ -649,7 +666,18 @@ object DecisionTree extends Serializable with Logging { | |
| return InformationGainStats.invalidInformationGainStats | ||
| } | ||
|
|
||
| new InformationGainStats(gain, impurity, leftImpurity, rightImpurity) | ||
| // calculate left and right predict | ||
| val leftPredict = calculatePredict(leftImpurityCalculator) | ||
| val rightPredict = calculatePredict(rightImpurityCalculator) | ||
|
|
||
| new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, | ||
| leftPredict, rightPredict) | ||
| } | ||
|
|
||
| private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = { | ||
| val predict = impurityCalculator.predict | ||
| val prob = impurityCalculator.prob(predict) | ||
| new Predict(predict, prob) | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -659,15 +687,15 @@ object DecisionTree extends Serializable with Logging { | |
| * @param rightImpurityCalculator right node aggregates for a split | ||
| * @return predict value and impurity for current node | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
| */ | ||
| private def calculatePredict( | ||
| private def calculatePredictImpurity( | ||
| leftImpurityCalculator: ImpurityCalculator, | ||
| rightImpurityCalculator: ImpurityCalculator): Predict = { | ||
| rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { | ||
| val parentNodeAgg = leftImpurityCalculator.copy | ||
| parentNodeAgg.add(rightImpurityCalculator) | ||
| val predict = parentNodeAgg.predict | ||
| val prob = parentNodeAgg.prob(predict) | ||
| val predict = calculatePredict(parentNodeAgg) | ||
| val impurity = parentNodeAgg.calculate() | ||
|
|
||
| new Predict(predict, prob) | ||
| (predict, impurity) | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -678,10 +706,16 @@ object DecisionTree extends Serializable with Logging { | |
| private def binsToBestSplit( | ||
| binAggregates: DTStatsAggregator, | ||
| splits: Array[Array[Split]], | ||
| featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = { | ||
| featuresForNode: Option[Array[Int]], | ||
| node: Node): (Split, InformationGainStats, Predict) = { | ||
|
|
||
| // calculate predict only once | ||
| var predict: Option[Predict] = None | ||
| // calculate predict and impurity if current node is top node | ||
| val level = Node.indexToLevel(node.id) | ||
| var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) { | ||
| None | ||
| } else { | ||
| Some((node.predict, node.impurity)) | ||
| } | ||
|
|
||
| // For each (feature, split), calculate the gain, and select the best (feature, split). | ||
| val (bestSplit, bestSplitStats) = | ||
|
|
@@ -708,9 +742,10 @@ object DecisionTree extends Serializable with Logging { | |
| val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) | ||
| val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) | ||
| rightChildStats.subtract(leftChildStats) | ||
| predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) | ||
| predictWithImpurity = Some(predictWithImpurity.getOrElse( | ||
| calculatePredictImpurity(leftChildStats, rightChildStats))) | ||
| val gainStats = calculateGainForSplit(leftChildStats, | ||
| rightChildStats, binAggregates.metadata) | ||
| rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) | ||
| (splitIdx, gainStats) | ||
| }.maxBy(_._2.gain) | ||
| (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) | ||
|
|
@@ -722,9 +757,10 @@ object DecisionTree extends Serializable with Logging { | |
| Range(0, numSplits).map { splitIndex => | ||
| val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) | ||
| val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) | ||
| predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) | ||
| predictWithImpurity = Some(predictWithImpurity.getOrElse( | ||
| calculatePredictImpurity(leftChildStats, rightChildStats))) | ||
| val gainStats = calculateGainForSplit(leftChildStats, | ||
| rightChildStats, binAggregates.metadata) | ||
| rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) | ||
| (splitIndex, gainStats) | ||
| }.maxBy(_._2.gain) | ||
| (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) | ||
|
|
@@ -794,9 +830,10 @@ object DecisionTree extends Serializable with Logging { | |
| val rightChildStats = | ||
| binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) | ||
| rightChildStats.subtract(leftChildStats) | ||
| predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) | ||
| predictWithImpurity = Some(predictWithImpurity.getOrElse( | ||
| calculatePredictImpurity(leftChildStats, rightChildStats))) | ||
| val gainStats = calculateGainForSplit(leftChildStats, | ||
| rightChildStats, binAggregates.metadata) | ||
| rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) | ||
| (splitIndex, gainStats) | ||
| }.maxBy(_._2.gain) | ||
| val categoriesForSplit = | ||
|
|
@@ -807,9 +844,7 @@ object DecisionTree extends Serializable with Logging { | |
| } | ||
| }.maxBy(_._2.gain) | ||
|
|
||
| assert(predict.isDefined, "must calculate predict for each node") | ||
|
|
||
| (bestSplit, bestSplitStats, predict.get) | ||
| (bestSplit, bestSplitStats, predictWithImpurity.get._1) | ||
| } | ||
|
|
||
| /** | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
val nodes = new Array[Node](numNodes)