Skip to content

Commit d1ef4f6

Browse files
committed
more documentation
1 parent 794ff4d commit d1ef4f6

File tree

1 file changed

+39
-21
lines changed

1 file changed

+39
-21
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -50,23 +50,34 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log
5050

5151
//Cache input RDD for speedup during multiple passes
5252
input.cache()
53+
logDebug("algo = " + strategy.algo)
5354

55+
//Finding the splits and the corresponding bins (interval between the splits) using a sample
56+
// of the input data.
5457
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
5558
logDebug("numSplits = " + bins(0).length)
59+
60+
//Noting numBins for the input data
5661
strategy.numBins = bins(0).length
5762

63+
//The depth of the decision tree
5864
val maxDepth = strategy.maxDepth
59-
65+
//The max number of nodes possible given the depth of the tree
6066
val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1
67+
//Initializing an array to hold filters applied to points for each node
6168
val filters = new Array[List[Filter]](maxNumNodes)
69+
//The filter at the top node is an empty list
6270
filters(0) = List()
71+
//Initializing an array to hold parent impurity calculations for each node
6372
val parentImpurities = new Array[Double](maxNumNodes)
6473
//Dummy value for top node (updated during first split calculation)
65-
//parentImpurities(0) = Double.MinValue
6674
val nodes = new Array[Node](maxNumNodes)
6775

68-
logDebug("algo = " + strategy.algo)
69-
76+
//The main idea here is to perform level-wise training of the decision tree nodes, thus
77+
// reducing the passes over the data from l to log2(l) where l is the total number of nodes.
78+
// Each data sample is checked for validity w.r.t. each node at a given level -- i.e.,
79+
// the sample is only used for the split calculation at the node if the sample would have
80+
// still survived the filters of the parent nodes.
7081
breakable {
7182
for (level <- 0 until maxDepth) {
7283

@@ -79,36 +90,41 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log
7990
level, filters, splits, bins)
8091

8192
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
82-
93+
//Extract info for nodes at the current level
8394
extractNodeInfo(nodeSplitStats, level, index, nodes)
95+
//Extract info for nodes at the next lower level
8496
extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
8597
filters)
8698
logDebug("final best split = " + nodeSplitStats._1)
8799

88100
}
89101
require(scala.math.pow(2, level) == splitsStatsForLevel.length)
90-
102+
//Check whether all the nodes at the current level are leaves
91103
val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
92104
logDebug("all leaf = " + allLeaf)
93-
if (allLeaf) break
105+
if (allLeaf) break //no more tree construction
94106

95107
}
96108
}
97109

110+
//Initialize the top or root node of the tree
98111
val topNode = nodes(0)
112+
//Build the full tree using the node info calculated in the level-wise best split calculations
99113
topNode.build(nodes)
100114

101-
val decisionTreeModel = {
102-
return new DecisionTreeModel(topNode, strategy.algo)
103-
}
104-
return decisionTreeModel
115+
//Return a decision tree model
116+
return new DecisionTreeModel(topNode, strategy.algo)
105117
}
106118

107-
119+
/**
120+
* Extract the decision tree node information for the given tree level and node index
121+
*/
108122
private def extractNodeInfo(
109123
nodeSplitStats: (Split, InformationGainStats),
110-
level: Int, index: Int,
111-
nodes: Array[Node]) {
124+
level: Int,
125+
index: Int,
126+
nodes: Array[Node])
127+
: Unit = {
112128

113129
val split = nodeSplitStats._1
114130
val stats = nodeSplitStats._2
@@ -119,35 +135,37 @@ class DecisionTree private(val strategy: Strategy) extends Serializable with Log
119135
nodes(nodeIndex) = node
120136
}
121137

138+
/**
139+
* Extract the decision tree node information for the children of the node
140+
*/
122141
private def extractInfoForLowerLevels(
123142
level: Int,
124143
index: Int,
125144
maxDepth: Int,
126145
nodeSplitStats: (Split, InformationGainStats),
127146
parentImpurities: Array[Double],
128-
filters: Array[List[Filter]]) {
147+
filters: Array[List[Filter]])
148+
: Unit = {
129149

150+
// 0 corresponds to the left child node and 1 corresponds to the right child node.
130151
for (i <- 0 to 1) {
131-
152+
//Calculating the index of the node from the node level and the index at the current level
132153
val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i
133-
134154
if (level < maxDepth - 1) {
135-
136155
val impurity = if (i == 0) {
137156
nodeSplitStats._2.leftImpurity
138157
} else {
139158
nodeSplitStats._2.rightImpurity
140159
}
141-
142160
logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity)
161+
//noting the parent impurities
143162
parentImpurities(nodeIndex) = impurity
163+
//noting the parent's filters for the child nodes
144164
val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1)
145165
filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2)
146-
147166
for (filter <- filters(nodeIndex)) {
148167
logDebug("Filter = " + filter)
149168
}
150-
151169
}
152170
}
153171
}

0 commit comments

Comments
 (0)