Commit 14f222f

chouqin authored and mengxr committed
[SPARK-3158][MLLIB]Avoid 1 extra aggregation for DecisionTree training
Currently, the implementation does one unnecessary aggregation step. The aggregation step for level L (to choose splits) gives enough information to set the predictions of any leaf nodes at level L+1. We can use that info and skip the aggregation step for the last level of the tree (which only has leaf nodes).

### Implementation Details

Each node now has an `impurity` field, and `predict` is changed from type `Double` to type `Predict` (this can be used to compute prediction probabilities in the future). When computing the best split for each node, we also compute the impurity and predict for its child nodes, which are used to construct the newly allocated child nodes. So at level L, we have set the impurity and predict for nodes at level L+1. If level L+1 is the last level, we can avoid the aggregation.

What's more, the calculation of parent impurity for the top node of each tree needs to be treated differently, because we have to compute impurity and predict for it first. In `binsToBestSplit`, if the current node is a top node (level == 0), we calculate impurity and predict first. After finding the best split, the top node's predict and impurity are set to the calculated values. Non-top nodes' impurity and predict are already calculated and don't need to be recalculated.

I considered adding an initialization step to set the top nodes' impurity and predict so that all nodes could be treated the same way, but this would require a lot of code duplication (all the code for the seq operation (BinSeqOp) would need to be duplicated), so I chose the current approach.

CC mengxr manishamde jkbradley, please help me review this, thanks.

Author: Qiping Li <[email protected]>

Closes apache#2708 from chouqin/avoid-agg and squashes the following commits:

8e269ea [Qiping Li] adjust code and comments
eefeef1 [Qiping Li] adjust comments and check child nodes' impurity
c41b1b6 [Qiping Li] fix pyspark unit test
7ad7a71 [Qiping Li] fix unit test
822c912 [Qiping Li] add comments and unit test
e41d715 [Qiping Li] fix bug in test suite
6cc0333 [Qiping Li] SPARK-3158: Avoid 1 extra aggregation for DecisionTree training
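
The idea in the commit message can be illustrated with a small, self-contained sketch. It is not the MLlib code: `Predict` and `ChildInfo` here are simplified stand-ins, Gini impurity over per-class counts is assumed as the impurity measure, and `childInfo` is a hypothetical helper. It only shows that the class counts a parent already holds for each side of its best split are enough to fix a child's prediction, impurity and leaf status without another pass over the data.

```scala
// Hypothetical, simplified stand-ins for illustration; not the real MLlib classes.
final case class Predict(predict: Double, prob: Double)
final case class ChildInfo(predict: Predict, impurity: Double, isLeaf: Boolean)

object AvoidAggSketch {
  /** Gini impurity of a vector of per-class counts; 0.0 for an empty node. */
  def gini(counts: Array[Double]): Double = {
    val total = counts.sum
    if (total == 0.0) 0.0
    else 1.0 - counts.map(c => (c / total) * (c / total)).sum
  }

  /** Majority-class prediction and its empirical probability (counts must be non-empty). */
  def predictOf(counts: Array[Double]): Predict = {
    val total = counts.sum
    val best = counts.indexOf(counts.max)
    Predict(best.toDouble, if (total == 0.0) 0.0 else counts(best) / total)
  }

  /** Everything a child needs, derived from the parent's split statistics alone. */
  def childInfo(counts: Array[Double], childLevel: Int, maxDepth: Int): ChildInfo = {
    val impurity = gini(counts)
    // A child is a leaf if it sits on the last level or is already pure.
    ChildInfo(predictOf(counts), impurity, childLevel == maxDepth || impurity == 0.0)
  }
}
```

With `maxDepth = 2`, a child at level 2 built this way is already complete, so no aggregation is needed for that level.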
1 parent 13cab5b commit 14f222f


4 files changed: +197 -48 lines changed


mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 66 additions & 31 deletions
@@ -532,6 +532,14 @@ object DecisionTree extends Serializable with Logging {
       Some(mutableNodeToFeatures.toMap)
     }

+    // array of nodes to train indexed by node index in group
+    val nodes = new Array[Node](numNodes)
+    nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
+      nodesForTree.foreach { node =>
+        nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
+      }
+    }
+
     // Calculate best splits for all nodes in the group
     timer.start("chooseSplits")

@@ -568,7 +576,7 @@ object DecisionTree extends Serializable with Logging {

         // find best split for each node
         val (split: Split, stats: InformationGainStats, predict: Predict) =
-          binsToBestSplit(aggStats, splits, featuresForNode)
+          binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
         (nodeIndex, (split, stats, predict))
       }.collectAsMap()

@@ -587,17 +595,30 @@ object DecisionTree extends Serializable with Logging {
         // Extract info for this node. Create children if not leaf.
         val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth)
         assert(node.id == nodeIndex)
-        node.predict = predict.predict
+        node.predict = predict
         node.isLeaf = isLeaf
         node.stats = Some(stats)
+        node.impurity = stats.impurity
         logDebug("Node = " + node)

         if (!isLeaf) {
           node.split = Some(split)
-          node.leftNode = Some(Node.emptyNode(Node.leftChildIndex(nodeIndex)))
-          node.rightNode = Some(Node.emptyNode(Node.rightChildIndex(nodeIndex)))
-          nodeQueue.enqueue((treeIndex, node.leftNode.get))
-          nodeQueue.enqueue((treeIndex, node.rightNode.get))
+          val childIsLeaf = (Node.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
+          val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
+          val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
+          node.leftNode = Some(Node(Node.leftChildIndex(nodeIndex),
+            stats.leftPredict, stats.leftImpurity, leftChildIsLeaf))
+          node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex),
+            stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
+
+          // enqueue left child and right child if they are not leaves
+          if (!leftChildIsLeaf) {
+            nodeQueue.enqueue((treeIndex, node.leftNode.get))
+          }
+          if (!rightChildIsLeaf) {
+            nodeQueue.enqueue((treeIndex, node.rightNode.get))
+          }
+
           logDebug("leftChildIndex = " + node.leftNode.get.id +
             ", impurity = " + stats.leftImpurity)
           logDebug("rightChildIndex = " + node.rightNode.get.id +
@@ -617,7 +638,8 @@ object DecisionTree extends Serializable with Logging {
   private def calculateGainForSplit(
       leftImpurityCalculator: ImpurityCalculator,
       rightImpurityCalculator: ImpurityCalculator,
-      metadata: DecisionTreeMetadata): InformationGainStats = {
+      metadata: DecisionTreeMetadata,
+      impurity: Double): InformationGainStats = {
     val leftCount = leftImpurityCalculator.count
     val rightCount = rightImpurityCalculator.count

@@ -630,11 +652,6 @@ object DecisionTree extends Serializable with Logging {

     val totalCount = leftCount + rightCount

-    val parentNodeAgg = leftImpurityCalculator.copy
-    parentNodeAgg.add(rightImpurityCalculator)
-
-    val impurity = parentNodeAgg.calculate()
-
     val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
     val rightImpurity = rightImpurityCalculator.calculate()

@@ -649,25 +666,36 @@ object DecisionTree extends Serializable with Logging {
       return InformationGainStats.invalidInformationGainStats
     }

-    new InformationGainStats(gain, impurity, leftImpurity, rightImpurity)
+    // calculate left and right predict
+    val leftPredict = calculatePredict(leftImpurityCalculator)
+    val rightPredict = calculatePredict(rightImpurityCalculator)
+
+    new InformationGainStats(gain, impurity, leftImpurity, rightImpurity,
+      leftPredict, rightPredict)
+  }
+
+  private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = {
+    val predict = impurityCalculator.predict
+    val prob = impurityCalculator.prob(predict)
+    new Predict(predict, prob)
   }

   /**
    * Calculate predict value for current node, given stats of any split.
    * Note that this function is called only once for each node.
    * @param leftImpurityCalculator left node aggregates for a split
    * @param rightImpurityCalculator right node aggregates for a split
-   * @return predict value for current node
+   * @return predict value and impurity for current node
    */
-  private def calculatePredict(
+  private def calculatePredictImpurity(
       leftImpurityCalculator: ImpurityCalculator,
-      rightImpurityCalculator: ImpurityCalculator): Predict = {
+      rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = {
     val parentNodeAgg = leftImpurityCalculator.copy
     parentNodeAgg.add(rightImpurityCalculator)
-    val predict = parentNodeAgg.predict
-    val prob = parentNodeAgg.prob(predict)
+    val predict = calculatePredict(parentNodeAgg)
+    val impurity = parentNodeAgg.calculate()

-    new Predict(predict, prob)
+    (predict, impurity)
   }

   /**
@@ -678,10 +706,16 @@ object DecisionTree extends Serializable with Logging {
   private def binsToBestSplit(
       binAggregates: DTStatsAggregator,
       splits: Array[Array[Split]],
-      featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = {
+      featuresForNode: Option[Array[Int]],
+      node: Node): (Split, InformationGainStats, Predict) = {

-    // calculate predict only once
-    var predict: Option[Predict] = None
+    // calculate predict and impurity if current node is top node
+    val level = Node.indexToLevel(node.id)
+    var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) {
+      None
+    } else {
+      Some((node.predict, node.impurity))
+    }

     // For each (feature, split), calculate the gain, and select the best (feature, split).
     val (bestSplit, bestSplitStats) =
@@ -708,9 +742,10 @@ object DecisionTree extends Serializable with Logging {
               val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
               val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
               rightChildStats.subtract(leftChildStats)
-              predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
+              predictWithImpurity = Some(predictWithImpurity.getOrElse(
+                calculatePredictImpurity(leftChildStats, rightChildStats)))
               val gainStats = calculateGainForSplit(leftChildStats,
-                rightChildStats, binAggregates.metadata)
+                rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
               (splitIdx, gainStats)
             }.maxBy(_._2.gain)
           (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
@@ -722,9 +757,10 @@ object DecisionTree extends Serializable with Logging {
             Range(0, numSplits).map { splitIndex =>
               val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
               val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
-              predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
+              predictWithImpurity = Some(predictWithImpurity.getOrElse(
+                calculatePredictImpurity(leftChildStats, rightChildStats)))
               val gainStats = calculateGainForSplit(leftChildStats,
-                rightChildStats, binAggregates.metadata)
+                rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
               (splitIndex, gainStats)
             }.maxBy(_._2.gain)
           (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
@@ -794,9 +830,10 @@ object DecisionTree extends Serializable with Logging {
               val rightChildStats =
                 binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
               rightChildStats.subtract(leftChildStats)
-              predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
+              predictWithImpurity = Some(predictWithImpurity.getOrElse(
+                calculatePredictImpurity(leftChildStats, rightChildStats)))
               val gainStats = calculateGainForSplit(leftChildStats,
-                rightChildStats, binAggregates.metadata)
+                rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
               (splitIndex, gainStats)
             }.maxBy(_._2.gain)
           val categoriesForSplit =
@@ -807,9 +844,7 @@ object DecisionTree extends Serializable with Logging {
         }
       }.maxBy(_._2.gain)

-    assert(predict.isDefined, "must calculate predict for each node")
-
-    (bestSplit, bestSplitStats, predict.get)
+    (bestSplit, bestSplitStats, predictWithImpurity.get._1)
   }

   /**
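
The change to `calculateGainForSplit` above passes the parent's impurity in as a parameter instead of recomputing it from the merged left and right aggregates. A minimal sketch of that calculation follows. It assumes the usual weighted-impurity definition of information gain; the object name, the plain `Long`/`Double` parameters, and the empty-child check are simplifications invented for illustration (the real method takes `ImpurityCalculator` instances and a `DecisionTreeMetadata`).

```scala
object GainSketch {
  /**
   * Information gain of a split, given a parent impurity that was computed earlier.
   * Returns Double.MinValue when either child is empty; this is a simplified
   * validity check assumed for the sketch, not the checks used by MLlib.
   */
  def gain(
      parentImpurity: Double,
      leftCount: Long,
      leftImpurity: Double,
      rightCount: Long,
      rightImpurity: Double): Double = {
    if (leftCount == 0L || rightCount == 0L) {
      Double.MinValue
    } else {
      val total = (leftCount + rightCount).toDouble
      // Parent impurity minus the count-weighted impurity of the two children.
      parentImpurity - (leftCount / total) * leftImpurity - (rightCount / total) * rightImpurity
    }
  }
}
```

Because the parent impurity is the same for every candidate split of a node, computing it once (or inheriting it from the level above) and reusing it across all candidates avoids repeated copy-and-add work on the aggregates.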

mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala

Lines changed: 7 additions & 2 deletions
@@ -26,13 +26,17 @@ import org.apache.spark.annotation.DeveloperApi
  * @param impurity current node impurity
  * @param leftImpurity left node impurity
  * @param rightImpurity right node impurity
+ * @param leftPredict left node predict
+ * @param rightPredict right node predict
  */
 @DeveloperApi
 class InformationGainStats(
     val gain: Double,
     val impurity: Double,
     val leftImpurity: Double,
-    val rightImpurity: Double) extends Serializable {
+    val rightImpurity: Double,
+    val leftPredict: Predict,
+    val rightPredict: Predict) extends Serializable {

   override def toString = {
     "gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
@@ -58,5 +62,6 @@ private[tree] object InformationGainStats {
    * denote that current split doesn't satisfies minimum info gain or
    * minimum number of instances per node.
    */
-  val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0)
+  val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0,
+    new Predict(0.0, 0.0), new Predict(0.0, 0.0))
 }

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala

Lines changed: 30 additions & 7 deletions
@@ -32,7 +32,8 @@ import org.apache.spark.mllib.linalg.Vector
  *
  * @param id integer node id, from 1
  * @param predict predicted value at the node
- * @param isLeaf whether the leaf is a node
+ * @param impurity current node impurity
+ * @param isLeaf whether the node is a leaf
  * @param split split to calculate left and right nodes
  * @param leftNode left child
  * @param rightNode right child
@@ -41,15 +42,16 @@ import org.apache.spark.mllib.linalg.Vector
 @DeveloperApi
 class Node (
     val id: Int,
-    var predict: Double,
+    var predict: Predict,
+    var impurity: Double,
     var isLeaf: Boolean,
     var split: Option[Split],
     var leftNode: Option[Node],
     var rightNode: Option[Node],
     var stats: Option[InformationGainStats]) extends Serializable with Logging {

   override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
-    "split = " + split + ", stats = " + stats
+    "impurity = " + impurity + "split = " + split + ", stats = " + stats

   /**
    * build the left node and right nodes if not leaf
@@ -62,6 +64,7 @@ class Node (
     logDebug("id = " + id + ", split = " + split)
     logDebug("stats = " + stats)
     logDebug("predict = " + predict)
+    logDebug("impurity = " + impurity)
     if (!isLeaf) {
       leftNode = Some(nodes(Node.leftChildIndex(id)))
       rightNode = Some(nodes(Node.rightChildIndex(id)))
@@ -77,7 +80,7 @@ class Node (
    */
   def predict(features: Vector) : Double = {
     if (isLeaf) {
-      predict
+      predict.predict
     } else{
       if (split.get.featureType == Continuous) {
         if (features(split.get.feature) <= split.get.threshold) {
@@ -109,7 +112,7 @@ class Node (
     } else {
       Some(rightNode.get.deepCopy())
     }
-    new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats)
+    new Node(id, predict, impurity, isLeaf, split, leftNodeCopy, rightNodeCopy, stats)
   }

   /**
@@ -154,7 +157,7 @@ class Node (
     }
     val prefix: String = " " * indentFactor
     if (isLeaf) {
-      prefix + s"Predict: $predict\n"
+      prefix + s"Predict: ${predict.predict}\n"
     } else {
       prefix + s"If ${splitToString(split.get, left=true)}\n" +
         leftNode.get.subtreeToString(indentFactor + 1) +
@@ -170,7 +173,27 @@ private[tree] object Node {
   /**
    * Return a node with the given node id (but nothing else set).
    */
-  def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, 0, false, None, None, None, None)
+  def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, new Predict(Double.MinValue), -1.0,
+    false, None, None, None, None)
+
+  /**
+   * Construct a node with nodeIndex, predict, impurity and isLeaf parameters.
+   * This is used in `DecisionTree.findBestSplits` to construct child nodes
+   * after finding the best splits for parent nodes.
+   * Other fields are set at next level.
+   * @param nodeIndex integer node id, from 1
+   * @param predict predicted value at the node
+   * @param impurity current node impurity
+   * @param isLeaf whether the node is a leaf
+   * @return new node instance
+   */
+  def apply(
+      nodeIndex: Int,
+      predict: Predict,
+      impurity: Double,
+      isLeaf: Boolean): Node = {
+    new Node(nodeIndex, predict, impurity, isLeaf, None, None, None, None)
+  }

   /**
    * Return the index of the left child of this node.
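
The diff relies on `Node.leftChildIndex`, `Node.rightChildIndex` and `Node.indexToLevel`, whose bodies are not shown here. The sketch below assumes a binary-heap numbering (root id 1 at level 0, children of node i at 2i and 2i+1), which is consistent with "integer node id, from 1" and the `level == 0` check for top nodes. `NodeIndexSketch` is a hypothetical name and the bodies are an illustration under that assumption, not a copy of the MLlib implementation.

```scala
object NodeIndexSketch {
  /** Index of the left child under heap numbering (assumed layout). */
  def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1

  /** Index of the right child under heap numbering (assumed layout). */
  def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1

  /** Level of a node: 0 for the root (id 1), 1 for ids 2 and 3, and so on. */
  def indexToLevel(nodeIndex: Int): Int = {
    require(nodeIndex >= 1, "node ids start at 1")
    // floor(log2(nodeIndex)), computed from the position of the highest set bit
    31 - Integer.numberOfLeadingZeros(nodeIndex)
  }
}
```

Under this numbering, the `childIsLeaf` test in `DecisionTree.scala` reduces to integer arithmetic on the parent's id, so a child's leaf status is known before the child is ever dequeued.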
