1717
1818package org .apache .spark .mllib .tree
1919
20+ import org .apache .spark .SparkContext ._
2021import org .apache .spark .rdd .RDD
22+ import org .apache .spark .mllib .tree .model ._
23+ import org .apache .spark .Logging
2124import org .apache .spark .mllib .regression .LabeledPoint
22- import org .apache .spark .mllib .tree .model .{ Split , Bin , DecisionTreeModel }
25+ import org .apache .spark .mllib .tree .model .Split
2326
2427
2528class DecisionTree (val strategy : Strategy ) {
@@ -30,25 +33,180 @@ class DecisionTree(val strategy : Strategy) {
3033 input.cache()
3134
3235 // TODO: Find all splits and bins using quantiles including support for categorical features, single-pass
36+ // TODO: Think about broadcasting this
3337 val (splits, bins) = DecisionTree .find_splits_bins(input, strategy)
3438
3539 // TODO: Level-wise training of tree and obtain Decision Tree model
3640
41+ val maxDepth = strategy.maxDepth
42+
43+ val maxNumNodes = scala.math.pow(2 ,maxDepth).toInt - 1
44+ val filters = new Array [List [Filter ]](maxNumNodes)
45+
46+ for (level <- 0 until maxDepth){
47+ // Find best split for all nodes at a level
48+ val numNodes = scala.math.pow(2 ,level).toInt
49+ val bestSplits = DecisionTree .findBestSplits(input, strategy, level, filters,splits,bins)
50+ // TODO: update filters and decision tree model
51+ }
3752
3853 return new DecisionTreeModel ()
3954 }
4055
4156}
4257
43- object DecisionTree {
58+ object DecisionTree extends Logging {
59+
/**
 * Maps every training sample to a flat array of (bin index, label) pairs, one pair per
 * feature per node at the given tree level, as the first step towards finding the best
 * split for each node of that level.
 *
 * Aggregating the bins and choosing the best split are not implemented yet, so this
 * currently only builds the (lazy) bin-mapped RDD and returns an empty array.
 *
 * @param input    training samples
 * @param strategy training configuration; only `numSplits` is read here
 * @param level    zero-based depth of the tree level being processed
 * @param filters  per-node filter lists, consulted for every level other than 0
 * @param splits   candidate splits per feature (not used yet)
 * @param bins     candidate bins per feature, indexed as bins(featureIndex)(binIndex)
 * @return the best split per node; currently always an empty array
 */
def findBestSplits(
    input: RDD[LabeledPoint],
    strategy: Strategy,
    level: Int,
    filters: Array[List[Filter]],
    splits: Array[Array[Split]],
    bins: Array[Array[Bin]]): Array[Split] = {

  // These values are computed once, outside the function passed to `input.map`:
  // running an RDD action such as `take` from inside a transformation is not supported.
  val numNodes = scala.math.pow(2, level).toInt
  // Find the number of features by looking at the first sample
  val numFeatures = input.take(1)(0).features.length
  val numBins = strategy.numSplits

  // Returns the filters a sample must satisfy to belong to the given node of this level.
  def findParentFilters(nodeIndex: Int): List[Filter] = {
    if (level == 0) {
      List[Filter]()
    } else {
      // NOTE(review): nothing visible in this file populates `filters` yet, so entries may
      // still be null for level > 0; the indexing convention should be confirmed as well.
      val nodeFilterIndex = scala.math.pow(2, level).toInt + nodeIndex
      val parentFilterIndex = nodeFilterIndex / 2
      filters(parentFilterIndex)
    }
  }

  // A sample is valid for a node when it satisfies every parent filter:
  // -1 keeps values <= threshold, 0 keeps values == threshold, 1 keeps values > threshold.
  def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = {
    val features = labeledPoint.features
    parentFilters.forall { filter =>
      val value = features(filter.split.feature)
      val threshold = filter.split.threshold
      filter.comparison match {
        case -1 => value <= threshold
        case 0 => value == threshold
        case 1 => value > threshold
      }
    }
  }

  // Returns the index of the bin whose (low, high] interval contains the sample's value
  // for the given feature. The interval is closed on the right so that it agrees with the
  // "<= threshold" convention used by the filters above.
  def findBin(featureIndex: Int, labeledPoint: LabeledPoint): Int = {
    val featureValue = labeledPoint.features(featureIndex)
    // TODO: Do binary search
    (0 until numBins).find { binIndex =>
      val bin = bins(featureIndex)(binIndex)
      // TODO: Remove this requirement post basic functional testing
      require(bin.lowSplit.feature == featureIndex)
      require(bin.highSplit.feature == featureIndex)
      bin.lowSplit.threshold < featureValue && featureValue <= bin.highSplit.threshold
    }.getOrElse(throw new UnknownError("no bin was found."))
  }

  // Layout of the result: for node n and feature f, slot 2 * (numFeatures * n + f) holds
  // the bin index (-1 when the sample does not belong to the node) and the following
  // slot holds the sample's label.
  def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = {
    // TODO: Bit pack more by removing redundant label storage
    val arr = new Array[Double](2 * numFeatures * numNodes)
    for (nodeIndex <- 0 until numNodes) {
      val parentFilters = findParentFilters(nodeIndex)
      // Find out whether the sample qualifies for the particular node
      val sampleValid = isSampleValid(parentFilters, labeledPoint)
      val shift = 2 * numFeatures * nodeIndex
      for (featureIndex <- 0 until numFeatures) {
        val binIndex = if (sampleValid) findBin(featureIndex, labeledPoint) else -1
        arr(shift + 2 * featureIndex) = binIndex.toDouble
        arr(shift + 2 * featureIndex + 1) = labeledPoint.label
      }
    }
    arr
  }

  val binMappedRDD = input.map(labeledPoint => findBinsForLevel(labeledPoint))
  // calculate bin aggregates
  // find best split

  Array[Split]()
}
150+
/**
 * Computes the candidate splits and bins of every feature from a sample of the input.
 *
 * With the "sort" strategy each feature gets `numSplits - 1` splits, taken at evenly
 * spaced positions of the sorted sampled values, and `numSplits` bins: one between each
 * pair of consecutive splits plus two open-ended bins delimited by `DummyLowSplit` and
 * `DummyHighSplit`. The "minMax" strategy only allocates empty arrays.
 *
 * @param input    training samples
 * @param strategy training configuration; `numSplits` and `quantileCalculationStrategy`
 *                 are read here
 * @return the splits and the bins, both indexed by feature first
 * @throws IllegalArgumentException if the sample does not contain more than `numSplits`
 *                                  points, or if "sort" is used with fewer than 2 splits
 * @throws UnsupportedOperationException for the "approximateHistogram" strategy
 */
def find_splits_bins(
    input: RDD[LabeledPoint],
    strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = {

  val numSplits = strategy.numSplits
  logDebug("numSplits = " + numSplits)

  // Calculate the number of samples for approximate quantile calculation
  // TODO: Justify this calculation
  // Computed as a Long so that a large numSplits cannot overflow the product.
  val requiredSamples = numSplits.toLong * numSplits
  val count = input.count()
  val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
  logDebug("fraction of data used for calculating quantiles = " + fraction)

  // sampled input for RDD calculation
  val sampledInput = input.sample(false, fraction, 42).collect()
  val numSamples = sampledInput.length

  require(numSamples > numSplits, "length of input samples should be greater than numSplits")

  // Find the number of features by looking at the first sample
  val numFeatures = input.take(1)(0).features.length

  strategy.quantileCalculationStrategy match {
    case "sort" => {
      // The first and last bins each need a real split as their inner boundary.
      require(numSplits >= 2, "numSplits should be at least 2 for the sort strategy")

      val splits = Array.ofDim[Split](numFeatures, numSplits - 1)
      val bins = Array.ofDim[Bin](numFeatures, numSplits)
      val stride: Double = numSamples.toDouble / numSplits

      // Find all splits
      for (featureIndex <- 0 until numFeatures) {
        val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
        for (index <- 0 until numSplits - 1) {
          // Truncate after multiplying: truncating the stride first would pull every
          // split towards the low end of the sorted values.
          val sampleIndex = ((index + 1) * stride).toInt
          splits(featureIndex)(index) =
            new Split(featureIndex, featureSamples(sampleIndex), "continuous")
        }
      }

      // Find all bins
      for (featureIndex <- 0 until numFeatures) {
        val featureSplits = splits(featureIndex)
        bins(featureIndex)(0) =
          new Bin(new DummyLowSplit("continuous"), featureSplits(0), "continuous")
        for (index <- 1 until numSplits - 1) {
          bins(featureIndex)(index) =
            new Bin(featureSplits(index - 1), featureSplits(index), "continuous")
        }
        // The last bin starts at the last split, whose index is numSplits - 2.
        bins(featureIndex)(numSplits - 1) =
          new Bin(featureSplits(numSplits - 2), new DummyHighSplit("continuous"), "continuous")
      }

      (splits, bins)
    }
    case "minMax" => {
      (Array.ofDim[Split](numFeatures, numSplits), Array.ofDim[Bin](numFeatures, numSplits + 2))
    }
    case "approximateHistogram" => {
      throw new UnsupportedOperationException("approximate histogram not supported yet.")
    }
  }
}
53211
54212}
0 commit comments