@@ -83,7 +83,7 @@ object DecisionTree extends Serializable {
8383 level : Int ,
8484 filters : Array [List [Filter ]],
8585 splits : Array [Array [Split ]],
86- bins : Array [Array [Bin ]]) : Array [(Split , Double , Long , Long )] = {
86+ bins : Array [Array [Bin ]]) : Array [(Split , InformationGainStats )] = {
8787
8888 // Common calculations for multiple nested methods
8989 val numNodes = scala.math.pow(2 , level).toInt
@@ -241,7 +241,7 @@ object DecisionTree extends Serializable {
241241 featureIndex : Int ,
242242 index : Int ,
243243 rightNodeAgg : Array [Array [Double ]],
244- topImpurity : Double ) : ( Double , Long , Long ) = {
244+ topImpurity : Double ) : InformationGainStats = {
245245
246246 val left0Count = leftNodeAgg(featureIndex)(2 * index)
247247 val left1Count = leftNodeAgg(featureIndex)(2 * index + 1 )
@@ -251,20 +251,22 @@ object DecisionTree extends Serializable {
251251 val right1Count = rightNodeAgg(featureIndex)(2 * index + 1 )
252252 val rightCount = right0Count + right1Count
253253
254- if (leftCount == 0 ) return (0 , leftCount.toLong, rightCount.toLong)
254+ if (leftCount == 0 ) return new InformationGainStats (0 ,topImpurity,Double .MinValue ,0 ,topImpurity,rightCount.toLong)
255+ if (rightCount == 0 ) return new InformationGainStats (0 ,topImpurity,topImpurity,leftCount.toLong,Double .MinValue ,0 )
255256
256257 // println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount)
257258 val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
258259
259- if (rightCount == 0 ) return (0 , leftCount.toLong, rightCount.toLong)
260260
261261 // println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount)
262262 val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)
263263
264264 val leftWeight = leftCount.toDouble / (leftCount + rightCount)
265265 val rightWeight = rightCount.toDouble / (leftCount + rightCount)
266266
267- (topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity, leftCount.toLong, rightCount.toLong)
267+ val gain = topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity
268+
269+ new InformationGainStats (gain,topImpurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong)
268270
269271 }
270272
@@ -307,9 +309,9 @@ object DecisionTree extends Serializable {
307309 }
308310
309311 def calculateGainsForAllNodeSplits (leftNodeAgg : Array [Array [Double ]], rightNodeAgg : Array [Array [Double ]], nodeImpurity : Double )
310- : Array [Array [( Double , Long , Long ) ]] = {
312+ : Array [Array [InformationGainStats ]] = {
311313
312- val gains = Array .ofDim[( Double , Long , Long ) ](numFeatures, numSplits - 1 )
314+ val gains = Array .ofDim[InformationGainStats ](numFeatures, numSplits - 1 )
313315
314316 for (featureIndex <- 0 until numFeatures) {
315317 for (index <- 0 until numSplits - 1 ) {
@@ -325,44 +327,43 @@ object DecisionTree extends Serializable {
325327
326328 @param binData Array[Double] of size 2*numSplits*numFeatures
327329 */
328- def binsToBestSplit (binData : Array [Double ], nodeImpurity : Double ) : (Split , Double , Long , Long ) = {
330+ def binsToBestSplit (binData : Array [Double ], nodeImpurity : Double ) : (Split , InformationGainStats ) = {
329331 println(" node impurity = " + nodeImpurity)
330332 val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
331333 val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)
332334
333335 // println("gains.size = " + gains.size)
334336 // println("gains(0).size = " + gains(0).size)
335337
336- val (bestFeatureIndex,bestSplitIndex, gain, leftCount, rightCount ) = {
338+ val (bestFeatureIndex,bestSplitIndex, gainStats ) = {
337339 var bestFeatureIndex = 0
338340 var bestSplitIndex = 0
339- var maxGain = Double .MinValue
340- var leftSamples = Long .MinValue
341- var rightSamples = Long .MinValue
341+ // Initialization with infeasible values
342+ var bestGainStats = new InformationGainStats (- 1.0 ,- 1.0 ,- 1.0 ,0 ,- 1.0 ,0 )
343+ // var maxGain = Double.MinValue
344+ // var leftSamples = Long.MinValue
345+ // var rightSamples = Long.MinValue
342346 for (featureIndex <- 0 until numFeatures) {
343347 for (splitIndex <- 0 until numSplits - 1 ){
344- val gain = gains(featureIndex)(splitIndex)
348+ val gainStats = gains(featureIndex)(splitIndex)
345349 // println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain)
346- if (gain._1 > maxGain) {
347- maxGain = gain._1
348- leftSamples = gain._2
349- rightSamples = gain._3
350+ if (gainStats.gain > bestGainStats.gain) {
351+ bestGainStats = gainStats
350352 bestFeatureIndex = featureIndex
351353 bestSplitIndex = splitIndex
352354 println(" bestFeatureIndex = " + bestFeatureIndex + " , bestSplitIndex = " + bestSplitIndex
353- + " , maxGain = " + maxGain + " , leftSamples = " + leftSamples + " ,rightSamples = " + rightSamples )
355+ + " , gain stats = " + bestGainStats )
354356 }
355357 }
356358 }
357- (bestFeatureIndex,bestSplitIndex,maxGain,leftSamples,rightSamples )
359+ (bestFeatureIndex,bestSplitIndex,bestGainStats )
358360 }
359361
360- (splits(bestFeatureIndex)(bestSplitIndex),gain,leftCount,rightCount)
361- // TODO: Return array of node stats with split and impurity information
362+ (splits(bestFeatureIndex)(bestSplitIndex),gainStats)
362363 }
363364
364365 // Calculate best splits for all nodes at a given level
365- val bestSplits = new Array [(Split , Double , Long , Long )](numNodes)
366+ val bestSplits = new Array [(Split , InformationGainStats )](numNodes)
366367 for (node <- 0 until numNodes){
367368 val shift = 2 * node* numSplits* numFeatures
368369 val binsForNode = binAggregates.slice(shift,shift+ 2 * numSplits* numFeatures)
0 commit comments