Skip to content

Commit 2ab763b

Browse files
committed
Simplifications to DecisionTree code:
No longer pre-allocate parentImpurities array in main train() method. * parentImpurities values are now stored in individual nodes (in Node.stats.impurity). No longer using Node.build since tree structure is constructed on-the-fly. * Did not eliminate since it is public (Developer) API. Also: Updated DecisionTreeSuite test "Second level node building with vs. without groups" * generateOrderedLabeledPoints() modified so that it really does require 2 levels of internal nodes.
1 parent 7db5339 commit 2ab763b

File tree

2 files changed

+37
-70
lines changed

2 files changed

+37
-70
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 14 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
8989
s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
9090
// Number of nodes to allocate: max number of nodes possible given the depth of the tree, plus 1
9191
val maxNumNodesPlus1 = Node.startIndexInLevel(maxDepth + 1)
92-
// Initialize an array to hold parent impurity calculations for each node.
93-
val parentImpurities = new Array[Double](maxNumNodesPlus1)
9492
// dummy value for top node (updated during first split calculation)
9593
val nodes = new Array[Node](maxNumNodesPlus1)
9694

@@ -131,7 +129,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
131129
// Find best split for all nodes at a level.
132130
timer.start("findBestSplits")
133131
val splitsStatsForLevel: Array[(Split, InformationGainStats)] =
134-
DecisionTree.findBestSplits(treeInput, parentImpurities,
132+
DecisionTree.findBestSplits(treeInput,
135133
metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
136134
timer.stop("findBestSplits")
137135

@@ -158,20 +156,12 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
158156
nodes(parentNodeIndex).rightNode = Some(nodes(nodeIndex))
159157
}
160158
}
161-
// Extract info for nodes at the next lower level.
162-
timer.start("extractInfoForLowerLevels")
163159
if (level < maxDepth) {
164-
val leftChildIndex = Node.leftChildIndex(nodeIndex)
165-
val leftImpurity = stats.leftImpurity
166-
logDebug("leftChildIndex = " + leftChildIndex + ", impurity = " + leftImpurity)
167-
parentImpurities(leftChildIndex) = leftImpurity
168-
169-
val rightChildIndex = Node.rightChildIndex(nodeIndex)
170-
val rightImpurity = stats.rightImpurity
171-
logDebug("rightChildIndex = " + rightChildIndex + ", impurity = " + rightImpurity)
172-
parentImpurities(rightChildIndex) = rightImpurity
160+
logDebug("leftChildIndex = " + Node.leftChildIndex(nodeIndex) +
161+
", impurity = " + stats.leftImpurity)
162+
logDebug("rightChildIndex = " + Node.rightChildIndex(nodeIndex) +
163+
", impurity = " + stats.rightImpurity)
173164
}
174-
timer.stop("extractInfoForLowerLevels")
175165
logDebug("final best split = " + split)
176166
}
177167
require(Node.maxNodesInLevel(level) == splitsStatsForLevel.length)
@@ -189,17 +179,12 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
189179
logDebug("Extracting tree model")
190180
logDebug("#####################################")
191181

192-
// Initialize the top or root node of the tree.
193-
val topNode = nodes(1)
194-
// Build the full tree using the node info calculated in the level-wise best split calculations.
195-
topNode.build(nodes)
196-
197182
timer.stop("total")
198183

199184
logInfo("Internal timing for DecisionTree:")
200185
logInfo(s"$timer")
201186

202-
new DecisionTreeModel(topNode, strategy.algo)
187+
new DecisionTreeModel(nodes(1), strategy.algo)
203188
}
204189

205190
}
@@ -408,7 +393,6 @@ object DecisionTree extends Serializable with Logging {
408393
* multiple groups if the level-wise training task could lead to memory overflow.
409394
*
410395
* @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
411-
* @param parentImpurities Impurities for all parent nodes for the current level
412396
* @param metadata Learning and dataset metadata
413397
* @param level Level of the tree
414398
* @param splits possible splits for all features, indexed (numFeatures)(numSplits)
@@ -418,7 +402,6 @@ object DecisionTree extends Serializable with Logging {
418402
*/
419403
private[tree] def findBestSplits(
420404
input: RDD[TreePoint],
421-
parentImpurities: Array[Double],
422405
metadata: DecisionTreeMetadata,
423406
level: Int,
424407
nodes: Array[Node],
@@ -438,14 +421,14 @@ object DecisionTree extends Serializable with Logging {
438421
// Iterate over each group of nodes at a level.
439422
var groupIndex = 0
440423
while (groupIndex < numGroups) {
441-
val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, metadata, level,
424+
val bestSplitsForGroup = findBestSplitsPerGroup(input, metadata, level,
442425
nodes, splits, bins, timer, numGroups, groupIndex)
443426
bestSplits = Array.concat(bestSplits, bestSplitsForGroup)
444427
groupIndex += 1
445428
}
446429
bestSplits
447430
} else {
448-
findBestSplitsPerGroup(input, parentImpurities, metadata, level, nodes, splits, bins, timer)
431+
findBestSplitsPerGroup(input, metadata, level, nodes, splits, bins, timer)
449432
}
450433
}
451434

@@ -585,7 +568,6 @@ object DecisionTree extends Serializable with Logging {
585568
* Returns an array of optimal splits for a group of nodes at a given level
586569
*
587570
* @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
588-
* @param parentImpurities Impurities for all parent nodes for the current level
589571
* @param metadata Learning and dataset metadata
590572
* @param level Level of the tree
591573
* @param nodes Array of all nodes in the tree. Used for matching data points to nodes.
@@ -597,7 +579,6 @@ object DecisionTree extends Serializable with Logging {
597579
*/
598580
private def findBestSplitsPerGroup(
599581
input: RDD[TreePoint],
600-
parentImpurities: Array[Double],
601582
metadata: DecisionTreeMetadata,
602583
level: Int,
603584
nodes: Array[Node],
@@ -709,10 +690,8 @@ object DecisionTree extends Serializable with Logging {
709690
// Iterating over all nodes at this level
710691
var nodeIndex = 0
711692
while (nodeIndex < numNodes) {
712-
val nodeImpurity = parentImpurities(globalNodeIndexOffset + nodeIndex)
713-
logDebug("node impurity = " + nodeImpurity)
714693
bestSplits(nodeIndex) =
715-
binsToBestSplit(binAggregates, nodeIndex, nodeImpurity, level, metadata, splits)
694+
binsToBestSplit(binAggregates, nodeIndex, level, metadata, splits)
716695
logDebug("best split = " + bestSplits(nodeIndex)._1)
717696
nodeIndex += 1
718697
}
@@ -725,13 +704,11 @@ object DecisionTree extends Serializable with Logging {
725704
* Calculate the information gain for a given (feature, split) based upon left/right aggregates.
726705
* @param leftImpurityCalculator left node aggregates for this (feature, split)
727706
* @param rightImpurityCalculator right node aggregate for this (feature, split)
728-
* @param topImpurity impurity of the parent node
729707
* @return information gain and statistics for all splits
730708
*/
731709
private def calculateGainForSplit(
732710
leftImpurityCalculator: ImpurityCalculator,
733711
rightImpurityCalculator: ImpurityCalculator,
734-
topImpurity: Double,
735712
level: Int,
736713
metadata: DecisionTreeMetadata): InformationGainStats = {
737714

@@ -741,18 +718,13 @@ object DecisionTree extends Serializable with Logging {
741718
val totalCount = leftCount + rightCount
742719
if (totalCount == 0) {
743720
// Return arbitrary prediction.
744-
return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
721+
return new InformationGainStats(0, 0, 0, 0, 0)
745722
}
746723

747724
val parentNodeAgg = leftImpurityCalculator.copy
748725
parentNodeAgg.add(rightImpurityCalculator)
749-
// impurity of parent node
750-
val impurity = if (level > 0) {
751-
topImpurity
752-
} else {
753-
parentNodeAgg.calculate()
754-
}
755726

727+
val impurity = parentNodeAgg.calculate()
756728
val predict = parentNodeAgg.predict
757729
val prob = parentNodeAgg.prob(predict)
758730

@@ -771,19 +743,15 @@ object DecisionTree extends Serializable with Logging {
771743
* Find the best split for a node.
772744
* @param binAggregates Bin statistics.
773745
* @param nodeIndex Index for node to split in this (level, group).
774-
* @param nodeImpurity Impurity of the node (nodeIndex).
775746
* @return tuple for best split: (Split, information gain)
776747
*/
777748
private def binsToBestSplit(
778749
binAggregates: DTStatsAggregator,
779750
nodeIndex: Int,
780-
nodeImpurity: Double,
781751
level: Int,
782752
metadata: DecisionTreeMetadata,
783753
splits: Array[Array[Split]]): (Split, InformationGainStats) = {
784754

785-
logDebug("node impurity = " + nodeImpurity)
786-
787755
// For each (feature, split), calculate the gain, and select the best (feature, split).
788756
Range(0, metadata.numFeatures).map { featureIndex =>
789757
val numSplits = metadata.numSplits(featureIndex)
@@ -803,8 +771,7 @@ object DecisionTree extends Serializable with Logging {
803771
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
804772
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
805773
rightChildStats.subtract(leftChildStats)
806-
val gainStats =
807-
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
774+
val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata)
808775
(splitIdx, gainStats)
809776
}.maxBy(_._2.gain)
810777
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
@@ -816,8 +783,7 @@ object DecisionTree extends Serializable with Logging {
816783
Range(0, numSplits).map { splitIndex =>
817784
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
818785
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
819-
val gainStats =
820-
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
786+
val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata)
821787
(splitIndex, gainStats)
822788
}.maxBy(_._2.gain)
823789
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
@@ -887,8 +853,7 @@ object DecisionTree extends Serializable with Logging {
887853
val rightChildStats =
888854
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
889855
rightChildStats.subtract(leftChildStats)
890-
val gainStats =
891-
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
856+
val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata)
892857
(splitIndex, gainStats)
893858
}.maxBy(_._2.gain)
894859
val categoriesForSplit =

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
271271
assert(bins(0).length === 0)
272272

273273
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
274-
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
274+
val bestSplits = DecisionTree.findBestSplits(treeInput, metadata, 0,
275275
new Array[Node](0), splits, bins, 10)
276276

277277
val split = bestSplits(0)._1
@@ -303,7 +303,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
303303

304304
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
305305
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
306-
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
306+
val bestSplits = DecisionTree.findBestSplits(treeInput, metadata, 0,
307307
new Array[Node](0), splits, bins, 10)
308308

309309
val split = bestSplits(0)._1
@@ -357,7 +357,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
357357
assert(bins(0).length === 100)
358358

359359
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
360-
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
360+
val bestSplits = DecisionTree.findBestSplits(treeInput, metadata, 0,
361361
new Array[Node](0), splits, bins, 10)
362362
assert(bestSplits.length === 1)
363363
assert(bestSplits(0)._1.feature === 0)
@@ -385,7 +385,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
385385
assert(bins(0).length === 100)
386386

387387
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
388-
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
388+
val bestSplits = DecisionTree.findBestSplits(treeInput, metadata, 0,
389389
new Array[Node](0), splits, bins, 10)
390390
assert(bestSplits.length === 1)
391391
assert(bestSplits(0)._1.feature === 0)
@@ -414,7 +414,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
414414
assert(bins(0).length === 100)
415415

416416
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
417-
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
417+
val bestSplits = DecisionTree.findBestSplits(treeInput, metadata, 0,
418418
new Array[Node](0), splits, bins, 10)
419419
assert(bestSplits.length === 1)
420420
assert(bestSplits(0)._1.feature === 0)
@@ -443,7 +443,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
443443
assert(bins(0).length === 100)
444444

445445
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
446-
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
446+
val bestSplits = DecisionTree.findBestSplits(treeInput, metadata, 0,
447447
new Array[Node](0), splits, bins, 10)
448448
assert(bestSplits.length === 1)
449449
assert(bestSplits(0)._1.feature === 0)
@@ -468,26 +468,25 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
468468
assert(bins(0).length === 100)
469469

470470
// Train a 1-node model
471-
val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100)
471+
val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1,
472+
numClassesForClassification = 2, maxBins = 100)
472473
val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
473474
val nodes: Array[Node] = new Array[Node](8)
474475
nodes(1) = modelOneNode.topNode
475476
nodes(1).leftNode = None
476477
nodes(1).rightNode = None
477478

478-
val parentImpurities = Array(0, 0.5, 0.5, 0.5)
479-
480479
// Single group second level tree construction.
481480
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
482-
val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, nodes,
481+
val bestSplits = DecisionTree.findBestSplits(treeInput, metadata, 1, nodes,
483482
splits, bins, 10)
484483
assert(bestSplits.length === 2)
485484
assert(bestSplits(0)._2.gain > 0)
486485
assert(bestSplits(1)._2.gain > 0)
487486

488487
// maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second
489488
// level tree construction.
490-
val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1,
489+
val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, metadata, 1,
491490
nodes, splits, bins, 0)
492491
assert(bestSplitsWithGroups.length === 2)
493492
assert(bestSplitsWithGroups(0)._2.gain > 0)
@@ -517,7 +516,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
517516

518517
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
519518
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
520-
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
519+
val bestSplits = DecisionTree.findBestSplits(treeInput, metadata, 0,
521520
new Array[Node](0), splits, bins, 10)
522521

523522
assert(bestSplits.length === 1)
@@ -582,7 +581,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
582581

583582
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
584583
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
585-
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
584+
val bestSplits = DecisionTree.findBestSplits(treeInput, metadata, 0,
586585
new Array[Node](0), splits, bins, 10)
587586

588587
assert(bestSplits.length === 1)
@@ -609,7 +608,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
609608

610609
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
611610
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
612-
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
611+
val bestSplits = DecisionTree.findBestSplits(treeInput, metadata, 0,
613612
new Array[Node](0), splits, bins, 10)
614613

615614
assert(bestSplits.length === 1)
@@ -636,7 +635,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
636635

637636
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
638637
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
639-
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
638+
val bestSplits = DecisionTree.findBestSplits(treeInput, metadata, 0,
640639
new Array[Node](0), splits, bins, 10)
641640

642641
assert(bestSplits.length === 1)
@@ -660,7 +659,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
660659

661660
val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
662661
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
663-
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
662+
val bestSplits = DecisionTree.findBestSplits(treeInput, metadata, 0,
664663
new Array[Node](0), splits, bins, 10)
665664

666665
assert(bestSplits.length === 1)
@@ -709,13 +708,16 @@ object DecisionTreeSuite {
709708
def generateOrderedLabeledPoints(): Array[LabeledPoint] = {
710709
val arr = new Array[LabeledPoint](1000)
711710
for (i <- 0 until 1000) {
712-
if (i < 600) {
713-
val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
714-
arr(i) = lp
711+
val label = if (i < 100) {
712+
0.0
713+
} else if (i < 500) {
714+
1.0
715+
} else if (i < 900) {
716+
0.0
715717
} else {
716-
val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i))
717-
arr(i) = lp
718+
1.0
718719
}
720+
arr(i) = new LabeledPoint(label, Vectors.dense(i.toDouble, 1000.0 - i))
719721
}
720722
arr
721723
}

0 commit comments

Comments
 (0)