@@ -50,23 +50,34 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log
5050
5151 // Cache input RDD for speedup during multiple passes
5252 input.cache()
53+ logDebug(" algo = " + strategy.algo)
5354
55+ // Finding the splits and the corresponding bins (interval between the splits) using a sample
56+ // of the input data.
5457 val (splits, bins) = DecisionTree .findSplitsBins(input, strategy)
5558 logDebug(" numSplits = " + bins(0 ).length)
59+
60+ // Noting numBins for the input data
5661 strategy.numBins = bins(0 ).length
5762
63+ // The depth of the decision tree
5864 val maxDepth = strategy.maxDepth
59-
65+ // The max number of nodes possible given the depth of the tree
6066 val maxNumNodes = scala.math.pow(2 , maxDepth).toInt - 1
67+ // Initalizing an array to hold filters applied to points for each node
6168 val filters = new Array [List [Filter ]](maxNumNodes)
69+ // The filter at the top node is an empty list
6270 filters(0 ) = List ()
71+ // Initializing an array to hold parent impurity calculations for each node
6372 val parentImpurities = new Array [Double ](maxNumNodes)
6473 // Dummy value for top node (updated during first split calculation)
65- // parentImpurities(0) = Double.MinValue
6674 val nodes = new Array [Node ](maxNumNodes)
6775
68- logDebug(" algo = " + strategy.algo)
69-
76+ // The main-idea here is to perform level-wise training of the decision tree nodes thus
77+ // reducing the passes over the data from l to log2(l) where l is the total number of nodes.
78+ // Each data sample is checked for validity w.r.t to each node at a given level -- i.e.,
79+ // the sample is only used for the split calculation at the node if the sampled would have
80+ // still survived the filters of the parent nodes.
7081 breakable {
7182 for (level <- 0 until maxDepth) {
7283
@@ -79,36 +90,41 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log
7990 level, filters, splits, bins)
8091
8192 for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
82-
93+ // Extract info for nodes at the current level
8394 extractNodeInfo(nodeSplitStats, level, index, nodes)
95+ // Extract info for nodes at the next lower level
8496 extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
8597 filters)
8698 logDebug(" final best split = " + nodeSplitStats._1)
8799
88100 }
89101 require(scala.math.pow(2 , level) == splitsStatsForLevel.length)
90-
102+ // Check whether all the nodes at the current level at leaves
91103 val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0 )
92104 logDebug(" all leaf = " + allLeaf)
93- if (allLeaf) break
105+ if (allLeaf) break // no more tree construction
94106
95107 }
96108 }
97109
110+ // Initialize the top or root node of the tree
98111 val topNode = nodes(0 )
112+ // Build the full tree using the node info calculated in the level-wise best split calculations
99113 topNode.build(nodes)
100114
101- val decisionTreeModel = {
102- return new DecisionTreeModel (topNode, strategy.algo)
103- }
104- return decisionTreeModel
115+ // Return a decision tree model
116+ return new DecisionTreeModel (topNode, strategy.algo)
105117 }
106118
107-
119+ /**
120+ * Extract the decision tree node information for th given tree level and node index
121+ */
108122 private def extractNodeInfo (
109123 nodeSplitStats : (Split , InformationGainStats ),
110- level : Int , index : Int ,
111- nodes : Array [Node ]) {
124+ level : Int ,
125+ index : Int ,
126+ nodes : Array [Node ])
127+ : Unit = {
112128
113129 val split = nodeSplitStats._1
114130 val stats = nodeSplitStats._2
@@ -119,35 +135,37 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log
119135 nodes(nodeIndex) = node
120136 }
121137
138+ /**
139+ * Extract the decision tree node information for the children of the node
140+ */
122141 private def extractInfoForLowerLevels (
123142 level : Int ,
124143 index : Int ,
125144 maxDepth : Int ,
126145 nodeSplitStats : (Split , InformationGainStats ),
127146 parentImpurities : Array [Double ],
128- filters : Array [List [Filter ]]) {
147+ filters : Array [List [Filter ]])
148+ : Unit = {
129149
150+ // 0 corresponds to the left child node and 1 corresponds to the right child node.
130151 for (i <- 0 to 1 ) {
131-
152+ // Calculating the index of the node from the node level and the index at the current level
132153 val nodeIndex = scala.math.pow(2 , level + 1 ).toInt - 1 + 2 * index + i
133-
134154 if (level < maxDepth - 1 ) {
135-
136155 val impurity = if (i == 0 ) {
137156 nodeSplitStats._2.leftImpurity
138157 } else {
139158 nodeSplitStats._2.rightImpurity
140159 }
141-
142160 logDebug(" nodeIndex = " + nodeIndex + " , impurity = " + impurity)
161+ // noting the parent impurities
143162 parentImpurities(nodeIndex) = impurity
163+ // noting the parents filters for the child nodes
144164 val childFilter = new Filter (nodeSplitStats._1, if (i == 0 ) - 1 else 1 )
145165 filters(nodeIndex) = childFilter :: filters((nodeIndex - 1 ) / 2 )
146-
147166 for (filter <- filters(nodeIndex)) {
148167 logDebug(" Filter = " + filter)
149168 }
150-
151169 }
152170 }
153171 }
0 commit comments