1717
1818package org .apache .spark .mllib .tree
1919
20- import java .util .Calendar
2120
2221import scala .collection .JavaConverters ._
2322
@@ -29,45 +28,12 @@ import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
2928import org .apache .spark .mllib .tree .configuration .Algo ._
3029import org .apache .spark .mllib .tree .configuration .FeatureType ._
3130import org .apache .spark .mllib .tree .configuration .QuantileStrategy ._
32- import org .apache .spark .mllib .tree .impl .TreePoint
31+ import org .apache .spark .mllib .tree .impl .{ TimeTracker , TreePoint }
3332import org .apache .spark .mllib .tree .impurity .{Impurities , Gini , Entropy , Impurity }
3433import org .apache .spark .mllib .tree .model ._
3534import org .apache .spark .rdd .RDD
3635import org .apache .spark .util .random .XORShiftRandom
3736
38- class TimeTracker {
39-
40- var tmpTime : Long = Calendar .getInstance().getTimeInMillis
41-
42- def reset (): Unit = {
43- tmpTime = Calendar .getInstance().getTimeInMillis
44- }
45-
46- def elapsed (): Long = {
47- Calendar .getInstance().getTimeInMillis - tmpTime
48- }
49-
50- var initTime : Long = 0 // Data retag and cache
51- var findSplitsBinsTime : Long = 0
52- var extractNodeInfoTime : Long = 0
53- var extractInfoForLowerLevelsTime : Long = 0
54- var findBestSplitsTime : Long = 0
55- var findBinsForLevelTime : Long = 0
56- var binAggregatesTime : Long = 0
57- var chooseSplitsTime : Long = 0
58-
59- override def toString : String = {
60- s " DecisionTree timing \n " +
61- s " initTime: $initTime\n " +
62- s " findSplitsBinsTime: $findSplitsBinsTime\n " +
63- s " extractNodeInfoTime: $extractNodeInfoTime\n " +
64- s " extractInfoForLowerLevelsTime: $extractInfoForLowerLevelsTime\n " +
65- s " findBestSplitsTime: $findBestSplitsTime\n " +
66- s " findBinsForLevelTime: $findBinsForLevelTime\n " +
67- s " binAggregatesTime: $binAggregatesTime\n " +
68- s " chooseSplitsTime: $chooseSplitsTime\n "
69- }
70- }
7137
7238/**
7339 * :: Experimental ::
@@ -90,26 +56,26 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
9056 def train (input : RDD [LabeledPoint ]): DecisionTreeModel = {
9157
9258 val timer = new TimeTracker ()
93- timer.reset()
9459
60+ timer.start(" total" )
61+
62+ timer.start(" init" )
9563 // Cache input RDD for speedup during multiple passes.
9664 val retaggedInput = input.retag(classOf [LabeledPoint ])
9765 logDebug(" algo = " + strategy.algo)
98-
99- timer.initTime += timer.elapsed()
100- timer.reset()
66+ timer.stop(" init" )
10167
10268 // Find the splits and the corresponding bins (interval between the splits) using a sample
10369 // of the input data.
70+ timer.start(" findSplitsBins" )
10471 val (splits, bins) = DecisionTree .findSplitsBins(retaggedInput, strategy)
10572 val numBins = bins(0 ).length
73+ timer.stop(" findSplitsBins" )
10674 logDebug(" numBins = " + numBins)
10775
108- timer.findSplitsBinsTime += timer.elapsed()
109-
110- timer.reset()
76+ timer.start(" init" )
11177 val treeInput = TreePoint .convertToTreeRDD(retaggedInput, strategy, bins)
112- timer.initTime += timer.elapsed( )
78+ timer.stop( " init " )
11379
11480 // depth of the decision tree
11581 val maxDepth = strategy.maxDepth
@@ -166,21 +132,21 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
166132
167133
168134 // Find best split for all nodes at a level.
169- timer.reset( )
135+ timer.start( " findBestSplits " )
170136 val splitsStatsForLevel = DecisionTree .findBestSplits(treeInput, parentImpurities,
171137 strategy, level, filters, splits, bins, maxLevelForSingleGroup, timer)
172- timer.findBestSplitsTime += timer.elapsed( )
138+ timer.stop( " findBestSplits " )
173139
174140 for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
175- timer.reset( )
141+ timer.start( " extractNodeInfo " )
176142 // Extract info for nodes at the current level.
177143 extractNodeInfo(nodeSplitStats, level, index, nodes)
178- timer.extractNodeInfoTime += timer.elapsed( )
179- timer.reset( )
144+ timer.stop( " extractNodeInfo " )
145+ timer.start( " extractInfoForLowerLevels " )
180146 // Extract info for nodes at the next lower level.
181147 extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
182148 filters)
183- timer.extractInfoForLowerLevelsTime += timer.elapsed( )
149+ timer.stop( " extractInfoForLowerLevels " )
184150 logDebug(" final best split = " + nodeSplitStats._1)
185151 }
186152 require(math.pow(2 , level) == splitsStatsForLevel.length)
@@ -194,8 +160,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
194160 }
195161 }
196162
197- println(timer)
198-
199163 logDebug(" #####################################" )
200164 logDebug(" Extracting tree model" )
201165 logDebug(" #####################################" )
@@ -205,6 +169,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
205169 // Build the full tree using the node info calculated in the level-wise best split calculations.
206170 topNode.build(nodes)
207171
172+ timer.stop(" total" )
173+
174+ // println(timer) // Print internal timing info.
175+
208176 new DecisionTreeModel (topNode, strategy.algo)
209177 }
210178
@@ -252,7 +220,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
252220 // noting the parents filters for the child nodes
253221 val childFilter = new Filter (nodeSplitStats._1, if (i == 0 ) - 1 else 1 )
254222 filters(nodeIndex) = childFilter :: filters((nodeIndex - 1 ) / 2 )
255- // println(s"extractInfoForLowerLevels: Set filters(node:$nodeIndex): ${filters(nodeIndex).mkString(", ")}")
256223 for (filter <- filters(nodeIndex)) {
257224 logDebug(" Filter = " + filter)
258225 }
@@ -491,7 +458,6 @@ object DecisionTree extends Serializable with Logging {
491458 maxLevelForSingleGroup : Int ,
492459 timer : TimeTracker = new TimeTracker ): Array [(Split , InformationGainStats )] = {
493460 // split into groups to avoid memory overflow during aggregation
494- // println(s"findBestSplits: level = $level")
495461 if (level > maxLevelForSingleGroup) {
496462 // When information for all nodes at a given level cannot be stored in memory,
497463 // the nodes are divided into multiple groups at each level with the number of groups
@@ -681,7 +647,6 @@ object DecisionTree extends Serializable with Logging {
681647 val parentFilters = findParentFilters(nodeIndex)
682648 // Find out whether the sample qualifies for the particular node.
683649 val sampleValid = isSampleValid(parentFilters, treePoint)
684- // println(s"==>findBinsForLevel: node:$nodeIndex, valid=$sampleValid, parentFilters:${parentFilters.mkString(",")}")
685650 val shift = 1 + numFeatures * nodeIndex
686651 if (! sampleValid) {
687652 // Mark one bin as -1 is sufficient.
@@ -699,12 +664,12 @@ object DecisionTree extends Serializable with Logging {
699664 arr
700665 }
701666
702- timer.reset( )
667+ timer.start( " findBinsForLevel " )
703668
704669 // Find feature bins for all nodes at a level.
705670 val binMappedRDD = input.map(x => findBinsForLevel(x))
706671
707- timer.findBinsForLevelTime += timer.elapsed( )
672+ timer.stop( " findBinsForLevel " )
708673
709674 /**
710675 * Increment aggregate in location for (node, feature, bin, label).
@@ -752,7 +717,6 @@ object DecisionTree extends Serializable with Logging {
752717 label : Double ,
753718 agg : Array [Double ],
754719 rightChildShift : Int ): Unit = {
755- // println(s"-- updateBinForUnorderedFeature node:$nodeIndex, feature:$featureIndex, label:$label.")
756720 // Find the bin index for this feature.
757721 val arrIndex = 1 + numFeatures * nodeIndex + featureIndex
758722 val featureValue = arr(arrIndex).toInt
@@ -830,10 +794,6 @@ object DecisionTree extends Serializable with Logging {
830794 // Check whether the instance was valid for this nodeIndex.
831795 val validSignalIndex = 1 + numFeatures * nodeIndex
832796 val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
833- if (level == 1 ) {
834- val nodeFilterIndex = math.pow(2 , level).toInt - 1 + nodeIndex + groupShift
835- // println(s"-multiclassWithCategoricalBinSeqOp: filter: ${filters(nodeFilterIndex)}")
836- }
837797 if (isSampleValidForNode) {
838798 // actual class label
839799 val label = arr(0 )
@@ -954,39 +914,15 @@ object DecisionTree extends Serializable with Logging {
954914 combinedAggregate
955915 }
956916
957- timer.reset()
958917
959918 // Calculate bin aggregates.
919+ timer.start(" binAggregates" )
960920 val binAggregates = {
961921 binMappedRDD.aggregate(Array .fill[Double ](binAggregateLength)(0 ))(binSeqOp,binCombOp)
962922 }
923+ timer.stop(" binAggregates" )
963924 logDebug(" binAggregates.length = " + binAggregates.length)
964925
965- timer.binAggregatesTime += timer.elapsed()
966- // 2 * numClasses * numBins * numFeatures * numNodes for unordered features.
967- // (left/right, node, feature, bin, label)
968- /*
969- println(s"binAggregates:")
970- for (i <- Range(0,2)) {
971- for (n <- Range(0,numNodes)) {
972- for (f <- Range(0,numFeatures)) {
973- for (b <- Range(0,4)) {
974- for (c <- Range(0,numClasses)) {
975- val idx = i * numClasses * numBins * numFeatures * numNodes +
976- n * numClasses * numBins * numFeatures +
977- f * numBins * numFeatures +
978- b * numFeatures +
979- c
980- if (binAggregates(idx) != 0) {
981- println(s"\t ($i, c:$c, b:$b, f:$f, n:$n): ${binAggregates(idx)}")
982- }
983- }
984- }
985- }
986- }
987- }
988- */
989-
990926 /**
991927 * Calculates the information gain for all splits based upon left/right split aggregates.
992928 * @param leftNodeAgg left node aggregates
@@ -1027,7 +963,6 @@ object DecisionTree extends Serializable with Logging {
1027963 val totalCount = leftTotalCount + rightTotalCount
1028964 if (totalCount == 0 ) {
1029965 // Return arbitrary prediction.
1030- // println(s"BLAH: feature $featureIndex, split $splitIndex. totalCount == 0")
1031966 return new InformationGainStats (0 , topImpurity, topImpurity, topImpurity, 0 )
1032967 }
1033968
@@ -1054,9 +989,6 @@ object DecisionTree extends Serializable with Logging {
1054989 }
1055990
1056991 val predict = indexOfLargestArrayElement(leftRightCounts)
1057- if (predict == 0 && featureIndex == 0 && splitIndex == 0 ) {
1058- // println(s"AGHGHGHHGHG: leftCounts: ${leftCounts.mkString(",")}, rightCounts: ${rightCounts.mkString(",")}")
1059- }
1060992 val prob = leftRightCounts(predict) / totalCount
1061993
1062994 val leftImpurity = if (leftTotalCount == 0 ) {
@@ -1209,7 +1141,6 @@ object DecisionTree extends Serializable with Logging {
12091141 }
12101142 splitIndex += 1
12111143 }
1212- // println(s"found Agg: $TMPDEBUG")
12131144 }
12141145
12151146 def findAggForRegression (
@@ -1369,7 +1300,6 @@ object DecisionTree extends Serializable with Logging {
13691300 bestGainStats = gainStats
13701301 bestFeatureIndex = featureIndex
13711302 bestSplitIndex = splitIndex
1372- // println(s" feature $featureIndex UPGRADED split $splitIndex: ${splits(featureIndex)(splitIndex)}: gainstats: $gainStats")
13731303 }
13741304 splitIndex += 1
13751305 }
@@ -1414,7 +1344,7 @@ object DecisionTree extends Serializable with Logging {
14141344 }
14151345 }
14161346
1417- timer.reset( )
1347+ timer.start( " chooseSplits " )
14181348
14191349 // Calculate best splits for all nodes at a given level
14201350 val bestSplits = new Array [(Split , InformationGainStats )](numNodes)
@@ -1427,10 +1357,9 @@ object DecisionTree extends Serializable with Logging {
14271357 val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
14281358 logDebug(" parent node impurity = " + parentNodeImpurity)
14291359 bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
1430- // println(s"bestSplits(node:$node): ${bestSplits(node)}")
14311360 node += 1
14321361 }
1433- timer.chooseSplitsTime += timer.elapsed( )
1362+ timer.stop( " chooseSplits " )
14341363
14351364 bestSplits
14361365 }
0 commit comments