@@ -87,7 +87,7 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
8787 topNode.build(nodes)
8888
8989 val decisionTreeModel = {
90- return new DecisionTreeModel (topNode)
90+ return new DecisionTreeModel (topNode, strategy.algo )
9191 }
9292
9393 return decisionTreeModel
@@ -98,14 +98,8 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
9898 val split = nodeSplitStats._1
9999 val stats = nodeSplitStats._2
100100 val nodeIndex = scala.math.pow(2 , level).toInt - 1 + index
101- val predict = {
102- val leftSamples = nodeSplitStats._2.leftSamples.toDouble
103- val rightSamples = nodeSplitStats._2.rightSamples.toDouble
104- val totalSamples = leftSamples + rightSamples
105- leftSamples / totalSamples
106- }
107101 val isLeaf = (stats.gain <= 0 ) || (level == strategy.maxDepth - 1 )
108- val node = new Node (nodeIndex, predict, isLeaf, Some (split), None , None , Some (stats))
102+ val node = new Node (nodeIndex, stats. predict, isLeaf, Some (split), None , None , Some (stats))
109103 logDebug(" Node = " + node)
110104 nodes(nodeIndex) = node
111105 }
@@ -370,8 +364,8 @@ object DecisionTree extends Serializable with Logging {
370364
371365 val impurity = if (level > 0 ) topImpurity else strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count)
372366
373- if (leftCount == 0 ) return new InformationGainStats (0 ,topImpurity,Double .MinValue ,0 , topImpurity,rightCount.toLong )
374- if (rightCount == 0 ) return new InformationGainStats (0 ,topImpurity,topImpurity,leftCount.toLong, Double .MinValue ,0 )
367+ if (leftCount == 0 ) return new InformationGainStats (0 ,topImpurity,Double .MinValue ,topImpurity,1 )
368+ if (rightCount == 0 ) return new InformationGainStats (0 ,topImpurity,topImpurity,Double .MinValue ,0 )
375369
376370 val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
377371 val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)
@@ -387,7 +381,9 @@ object DecisionTree extends Serializable with Logging {
387381 }
388382 }
389383
390- new InformationGainStats (gain,impurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong)
384+ val predict = leftCount / (leftCount + rightCount)
385+
386+ new InformationGainStats (gain,impurity,leftImpurity,rightImpurity,predict)
391387 }
392388 case Regression => {
393389 val leftCount = leftNodeAgg(featureIndex)(3 * index)
@@ -400,8 +396,8 @@ object DecisionTree extends Serializable with Logging {
400396
401397 val impurity = if (level > 0 ) topImpurity else strategy.impurity.calculate(leftCount + rightCount, leftSum + rightSum, leftSumSquares + rightSumSquares)
402398
403- if (leftCount == 0 ) return new InformationGainStats (0 ,topImpurity,Double .MinValue ,0 , topImpurity,rightCount.toLong )
404- if (rightCount == 0 ) return new InformationGainStats (0 ,topImpurity,topImpurity,leftCount.toLong, Double .MinValue ,0 )
399+ if (leftCount == 0 ) return new InformationGainStats (0 ,topImpurity,Double .MinValue ,topImpurity,rightSum / rightCount)
400+ if (rightCount == 0 ) return new InformationGainStats (0 ,topImpurity,topImpurity,Double .MinValue ,leftSum / leftCount )
405401
406402 val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares)
407403 val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares)
@@ -417,7 +413,7 @@ object DecisionTree extends Serializable with Logging {
417413 }
418414 }
419415
420- new InformationGainStats (gain,impurity,leftImpurity,leftCount.toLong, rightImpurity,rightCount.toLong )
416+ new InformationGainStats (gain,impurity,leftImpurity,rightImpurity,(leftSum + rightSum) / (leftCount + rightCount) )
421417
422418 }
423419 }
@@ -515,7 +511,7 @@ object DecisionTree extends Serializable with Logging {
515511 var bestFeatureIndex = 0
516512 var bestSplitIndex = 0
517513 // Initialization with infeasible values
518- var bestGainStats = new InformationGainStats (Double .MinValue ,- 1.0 ,- 1.0 ,0 , - 1.0 ,0 )
514+ var bestGainStats = new InformationGainStats (Double .MinValue ,- 1.0 ,- 1.0 ,- 1.0 ,- 1 )
519515 for (featureIndex <- 0 until numFeatures) {
520516 for (splitIndex <- 0 until numSplits - 1 ){
521517 val gainStats = gains(featureIndex)(splitIndex)
0 commit comments