Skip to content

Commit 6df35b9

Browse files
committed
regression predict logic
Signed-off-by: Manish Amde <[email protected]>
1 parent 53108ed commit 6df35b9

File tree

3 files changed

+27
-21
lines changed

3 files changed

+27
-21
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
8787
topNode.build(nodes)
8888

8989
val decisionTreeModel = {
90-
return new DecisionTreeModel(topNode)
90+
return new DecisionTreeModel(topNode, strategy.algo)
9191
}
9292

9393
return decisionTreeModel
@@ -98,14 +98,8 @@ class DecisionTree(val strategy : Strategy) extends Serializable with Logging {
9898
val split = nodeSplitStats._1
9999
val stats = nodeSplitStats._2
100100
val nodeIndex = scala.math.pow(2, level).toInt - 1 + index
101-
val predict = {
102-
val leftSamples = nodeSplitStats._2.leftSamples.toDouble
103-
val rightSamples = nodeSplitStats._2.rightSamples.toDouble
104-
val totalSamples = leftSamples + rightSamples
105-
leftSamples / totalSamples
106-
}
107101
val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1)
108-
val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats))
102+
val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
109103
logDebug("Node = " + node)
110104
nodes(nodeIndex) = node
111105
}
@@ -370,8 +364,8 @@ object DecisionTree extends Serializable with Logging {
370364

371365
val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count)
372366

373-
if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,0,topImpurity,rightCount.toLong)
374-
if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,leftCount.toLong,Double.MinValue,0)
367+
if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,topImpurity,1)
368+
if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,Double.MinValue,0)
375369

376370
val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
377371
val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)
@@ -387,7 +381,9 @@ object DecisionTree extends Serializable with Logging {
387381
}
388382
}
389383

390-
new InformationGainStats(gain,impurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong)
384+
val predict = leftCount / (leftCount + rightCount)
385+
386+
new InformationGainStats(gain,impurity,leftImpurity,rightImpurity,predict)
391387
}
392388
case Regression => {
393389
val leftCount = leftNodeAgg(featureIndex)(3 * index)
@@ -400,8 +396,8 @@ object DecisionTree extends Serializable with Logging {
400396

401397
val impurity = if (level > 0) topImpurity else strategy.impurity.calculate(leftCount + rightCount, leftSum + rightSum, leftSumSquares + rightSumSquares)
402398

403-
if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,0,topImpurity,rightCount.toLong)
404-
if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,leftCount.toLong,Double.MinValue,0)
399+
if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,topImpurity,rightSum/rightCount)
400+
if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,Double.MinValue,leftSum/leftCount)
405401

406402
val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares)
407403
val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares)
@@ -417,7 +413,7 @@ object DecisionTree extends Serializable with Logging {
417413
}
418414
}
419415

420-
new InformationGainStats(gain,impurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong)
416+
new InformationGainStats(gain,impurity,leftImpurity,rightImpurity,(leftSum + rightSum)/(leftCount+rightCount))
421417

422418
}
423419
}
@@ -515,7 +511,7 @@ object DecisionTree extends Serializable with Logging {
515511
var bestFeatureIndex = 0
516512
var bestSplitIndex = 0
517513
//Initialization with infeasible values
518-
var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,0,-1.0,0)
514+
var bestGainStats = new InformationGainStats(Double.MinValue,-1.0,-1.0,-1.0,-1)
519515
for (featureIndex <- 0 until numFeatures) {
520516
for (splitIndex <- 0 until numSplits - 1){
521517
val gainStats = gains(featureIndex)(splitIndex)

mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,19 @@
1717
package org.apache.spark.mllib.tree.model
1818

1919
import org.apache.spark.mllib.regression.LabeledPoint
20+
import org.apache.spark.mllib.tree.configuration.Algo._
2021

21-
class DecisionTreeModel(val topNode : Node) extends Serializable {
22+
class DecisionTreeModel(val topNode : Node, val algo : Algo) extends Serializable {
2223

23-
def predict(features : Array[Double]) = if (topNode.predictIfLeaf(features) >= 0.5) 0.0 else 1.0
24+
def predict(features : Array[Double]) = {
25+
algo match {
26+
case Classification => {
27+
if (topNode.predictIfLeaf(features) >= 0.5) 0.0 else 1.0
28+
}
29+
case Regression => {
30+
topNode.predictIfLeaf(features)
31+
}
32+
}
33+
}
2434

2535
}

mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ package org.apache.spark.mllib.tree.model
1919
class InformationGainStats(val gain : Double,
2020
val impurity: Double,
2121
val leftImpurity : Double,
22-
val leftSamples : Long,
22+
//val leftSamples : Long,
2323
val rightImpurity : Double,
24-
val rightSamples : Long) extends Serializable {
24+
//val rightSamples : Long
25+
val predict : Double) extends Serializable {
2526

2627
override def toString =
2728
"gain = " + gain + ", impurity = " + impurity + ", left impurity = "
28-
+ leftImpurity + ", leftSamples = " + leftSamples + ", right impurity = "
29-
+ rightImpurity + ", rightSamples = " + rightSamples
29+
+ leftImpurity + ", right impurity = " + rightImpurity + ", predict = " + predict
3030

3131

3232
}

0 commit comments

Comments
 (0)