 
 package org.apache.spark.mllib.tree
 
-
 import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.Experimental
@@ -32,6 +31,7 @@ import org.apache.spark.mllib.tree.impl.{TimeTracker, TreePoint}
 import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.random.XORShiftRandom
 
 
@@ -59,11 +59,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging
 
     timer.start("total")
 
-    // Cache input RDD for speedup during multiple passes.
     timer.start("init")
+
     val retaggedInput = input.retag(classOf[LabeledPoint])
     logDebug("algo = " + strategy.algo)
-    timer.stop("init")
 
     // Find the splits and the corresponding bins (interval between the splits) using a sample
     // of the input data.
@@ -73,9 +72,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging
     timer.stop("findSplitsBins")
     logDebug("numBins = " + numBins)
 
-    timer.start("init")
-    val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins).cache()
-    timer.stop("init")
+    // Cache input RDD for speedup during multiple passes.
+    val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins)
+      .persist(StorageLevel.MEMORY_AND_DISK)
 
     // depth of the decision tree
     val maxDepth = strategy.maxDepth
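The hunk above swaps `.cache()` for an explicit `.persist(StorageLevel.MEMORY_AND_DISK)`, so binned training points that do not fit in memory are spilled to disk rather than recomputed on every level-wise pass. A minimal sketch of the same pattern, assuming an illustrative SparkContext and input path:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.storage.StorageLevel

// Illustrative only: any RDD reused across multiple passes benefits from an
// explicit storage level; cache() is shorthand for persist(MEMORY_ONLY).
val sc = new SparkContext("local", "persist-sketch")   // hypothetical context
val points = sc.textFile("data.txt")                   // hypothetical input path
  .map(_.split(',').map(_.toDouble))
  .persist(StorageLevel.MEMORY_AND_DISK)               // spill to disk instead of recomputing

(0 until 5).foreach(_ => points.count())               // each pass reuses the persisted partitions
points.unpersist()
sc.stop()
```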
@@ -90,7 +89,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging
     // dummy value for top node (updated during first split calculation)
     val nodes = new Array[Node](maxNumNodes)
     // num features
-    val numFeatures = treeInput.take(1)(0).features.size
+    val numFeatures = treeInput.take(1)(0).binnedFeatures.size
 
     // Calculate level for single group construction
 
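This hunk, and the matching ones further down, rename `TreePoint.features` to `binnedFeatures`, making explicit that each point stores bin indices rather than raw feature values. A rough sketch of that distinction, using a hypothetical simplified class rather than the actual `TreePoint`:

```scala
// Hypothetical, simplified stand-in for the binned representation (not the
// actual private TreePoint class): each raw feature value is replaced by the
// index of the bin it falls into, so downstream loops only handle Ints.
case class BinnedPoint(label: Double, binnedFeatures: Array[Int])

// A raw point with features [0.3, 7.2] might map to bins [2, 5] for some split set.
val example = BinnedPoint(1.0, Array(2, 5))
val numFeatures = example.binnedFeatures.size   // mirrors the renamed accessor above
```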
@@ -110,6 +109,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging
       (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0)
     logDebug("max level for single group = " + maxLevelForSingleGroup)
 
+    timer.stop("init")
+
     /*
      * The main idea here is to perform level-wise training of the decision tree nodes thus
      * reducing the passes over the data from l to log2(l) where l is the total number of nodes.
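As a rough check of the arithmetic in the comment above (the numbers below are illustrative): a complete tree of depth d has l = 2^(d+1) - 1 nodes, so training one node at a time would scan the data l times, while level-wise training scans it once per level, i.e. d + 1 times, on the order of log2(l).

```scala
// Illustrative arithmetic only: a complete tree of depth d has l = 2^(d+1) - 1 nodes.
val d = 4
val l = (1 << (d + 1)) - 1      // 31 nodes
val nodeWisePasses  = l         // training one node at a time: 31 passes over the data
val levelWisePasses = d + 1     // training level by level: 5 passes, roughly log2(l)
```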
@@ -126,7 +127,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging
       logDebug("level = " + level)
       logDebug("#####################################")
 
-
       // Find best split for all nodes at a level.
       timer.start("findBestSplits")
       val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities,
@@ -167,8 +167,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging
 
     timer.stop("total")
 
-    logDebug("Internal timing for DecisionTree:")
-    logDebug(s"$timer")
+    logInfo("Internal timing for DecisionTree:")
+    logInfo(s"$timer")
 
     new DecisionTreeModel(topNode, strategy.algo)
   }
@@ -226,7 +226,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Logging
   }
 }
 
-
 object DecisionTree extends Serializable with Logging {
 
   /**
@@ -536,7 +535,7 @@ object DecisionTree extends Serializable with Logging {
     logDebug("numNodes = " + numNodes)
 
     // Find the number of features by looking at the first sample.
-    val numFeatures = input.first().features.size
+    val numFeatures = input.first().binnedFeatures.size
     logDebug("numFeatures = " + numFeatures)
 
     // numBins: Number of bins = 1 + number of possible splits
@@ -578,12 +577,12 @@ object DecisionTree extends Serializable with Logging {
       }
 
       // Apply each filter and check sample validity. Return false when invalid condition found.
-      for (filter <- parentFilters) {
+      parentFilters.foreach { filter =>
         val featureIndex = filter.split.feature
         val comparison = filter.comparison
         val isFeatureContinuous = filter.split.featureType == Continuous
         if (isFeatureContinuous) {
-          val binId = treePoint.features(featureIndex)
+          val binId = treePoint.binnedFeatures(featureIndex)
           val bin = bins(featureIndex)(binId)
           val featureValue = bin.highSplit.threshold
           val threshold = filter.split.threshold
@@ -598,9 +597,9 @@ object DecisionTree extends Serializable with Logging {
           val isUnorderedFeature =
             isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
           val featureValue = if (isUnorderedFeature) {
-            treePoint.features(featureIndex)
+            treePoint.binnedFeatures(featureIndex)
           } else {
-            val binId = treePoint.features(featureIndex)
+            val binId = treePoint.binnedFeatures(featureIndex)
             bins(featureIndex)(binId).category
           }
           val containsFeature = filter.split.categories.contains(featureValue)
@@ -648,9 +647,8 @@ object DecisionTree extends Serializable with Logging {
         arr(shift) = InvalidBinIndex
       } else {
         var featureIndex = 0
-        // TODO: Vectorize this
         while (featureIndex < numFeatures) {
-          arr(shift + featureIndex) = treePoint.features(featureIndex)
+          arr(shift + featureIndex) = treePoint.binnedFeatures(featureIndex)
           featureIndex += 1
         }
       }
@@ -660,9 +658,8 @@ object DecisionTree extends Serializable with Logging {
     }
 
     // Find feature bins for all nodes at a level.
-    timer.start("findBinsForLevel")
+    timer.start("aggregation")
     val binMappedRDD = input.map(x => findBinsForLevel(x))
-    timer.stop("findBinsForLevel")
 
     /**
      * Increment aggregate in location for (node, feature, bin, label).
@@ -907,13 +904,11 @@ object DecisionTree extends Serializable with Logging {
       combinedAggregate
     }
 
-
     // Calculate bin aggregates.
-    timer.start("binAggregates")
     val binAggregates = {
       binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp)
     }
-    timer.stop("binAggregates")
+    timer.stop("aggregation")
     logDebug("binAggregates.length = " + binAggregates.length)
 
     /**
@@ -1225,12 +1220,16 @@ object DecisionTree extends Serializable with Logging {
         nodeImpurity: Double): Array[Array[InformationGainStats]] = {
       val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)
 
-      for (featureIndex <- 0 until numFeatures) {
+      var featureIndex = 0
+      while (featureIndex < numFeatures) {
         val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
-        for (splitIndex <- 0 until numSplitsForFeature) {
+        var splitIndex = 0
+        while (splitIndex < numSplitsForFeature) {
           gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex,
             splitIndex, rightNodeAgg, nodeImpurity)
+          splitIndex += 1
         }
+        featureIndex += 1
       }
       gains
     }
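The hunk above replaces the nested `for` comprehensions with explicit `while` loops: in Scala of this era, `for (i <- 0 until n)` desugars into a `Range.foreach` call that invokes a closure per element, which is noticeably slower than plain counters in hot inner loops such as per-split gain computation. A small sketch of the rewrite pattern, with illustrative names:

```scala
// Illustrative only: the same O(rows * cols) fill written both ways.
val rows = 3
val cols = 4
val out = Array.ofDim[Double](rows, cols)

// For-comprehension form: concise, but desugars to Range.foreach with a closure per element.
for (i <- 0 until rows; j <- 0 until cols) {
  out(i)(j) = i * cols + j
}

// While-loop form adopted in the patch for the hot path: plain integer counters, no closures.
var i = 0
while (i < rows) {
  var j = 0
  while (j < cols) {
    out(i)(j) = i * cols + j
    j += 1
  }
  i += 1
}
```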