Skip to content

Commit 4798aae

Browse files
committed
added gain stats class
Signed-off-by: Manish Amde <[email protected]>
1 parent dad0afc commit 4798aae

File tree

2 files changed: 43 additions and 34 deletions.

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ object DecisionTree extends Serializable {
8383
level: Int,
8484
filters : Array[List[Filter]],
8585
splits : Array[Array[Split]],
86-
bins : Array[Array[Bin]]) : Array[(Split, Double, Long, Long)] = {
86+
bins : Array[Array[Bin]]) : Array[(Split, InformationGainStats)] = {
8787

8888
//Common calculations for multiple nested methods
8989
val numNodes = scala.math.pow(2, level).toInt
@@ -241,7 +241,7 @@ object DecisionTree extends Serializable {
241241
featureIndex: Int,
242242
index: Int,
243243
rightNodeAgg: Array[Array[Double]],
244-
topImpurity: Double) : (Double, Long, Long) = {
244+
topImpurity: Double) : InformationGainStats = {
245245

246246
val left0Count = leftNodeAgg(featureIndex)(2 * index)
247247
val left1Count = leftNodeAgg(featureIndex)(2 * index + 1)
@@ -251,20 +251,22 @@ object DecisionTree extends Serializable {
251251
val right1Count = rightNodeAgg(featureIndex)(2 * index + 1)
252252
val rightCount = right0Count + right1Count
253253

254-
if (leftCount == 0) return (0, leftCount.toLong, rightCount.toLong)
254+
if (leftCount == 0) return new InformationGainStats(0,topImpurity,Double.MinValue,0,topImpurity,rightCount.toLong)
255+
if (rightCount == 0) return new InformationGainStats(0,topImpurity,topImpurity,leftCount.toLong,Double.MinValue,0)
255256

256257
//println("left0count = " + left0Count + ", left1count = " + left1Count + ", leftCount = " + leftCount)
257258
val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
258259

259-
if (rightCount == 0) return (0, leftCount.toLong, rightCount.toLong)
260260

261261
//println("right0count = " + right0Count + ", right1count = " + right1Count + ", rightCount = " + rightCount)
262262
val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)
263263

264264
val leftWeight = leftCount.toDouble / (leftCount + rightCount)
265265
val rightWeight = rightCount.toDouble / (leftCount + rightCount)
266266

267-
(topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity, leftCount.toLong, rightCount.toLong)
267+
val gain = topImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity
268+
269+
new InformationGainStats(gain,topImpurity,leftImpurity,leftCount.toLong,rightImpurity,rightCount.toLong)
268270

269271
}
270272

@@ -307,9 +309,9 @@ object DecisionTree extends Serializable {
307309
}
308310

309311
def calculateGainsForAllNodeSplits(leftNodeAgg: Array[Array[Double]], rightNodeAgg: Array[Array[Double]], nodeImpurity: Double)
310-
: Array[Array[(Double,Long,Long)]] = {
312+
: Array[Array[InformationGainStats]] = {
311313

312-
val gains = Array.ofDim[(Double,Long,Long)](numFeatures, numSplits - 1)
314+
val gains = Array.ofDim[InformationGainStats](numFeatures, numSplits - 1)
313315

314316
for (featureIndex <- 0 until numFeatures) {
315317
for (index <- 0 until numSplits -1) {
@@ -325,44 +327,43 @@ object DecisionTree extends Serializable {
325327
326328
@param binData Array[Double] of size 2*numSplits*numFeatures
327329
*/
328-
def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : (Split, Double, Long, Long) = {
330+
def binsToBestSplit(binData : Array[Double], nodeImpurity : Double) : (Split, InformationGainStats) = {
329331
println("node impurity = " + nodeImpurity)
330332
val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData)
331333
val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)
332334

333335
//println("gains.size = " + gains.size)
334336
//println("gains(0).size = " + gains(0).size)
335337

336-
val (bestFeatureIndex,bestSplitIndex, gain, leftCount, rightCount) = {
338+
val (bestFeatureIndex,bestSplitIndex, gainStats) = {
337339
var bestFeatureIndex = 0
338340
var bestSplitIndex = 0
339-
var maxGain = Double.MinValue
340-
var leftSamples = Long.MinValue
341-
var rightSamples = Long.MinValue
341+
//Initialization with infeasible values
342+
var bestGainStats = new InformationGainStats(-1.0,-1.0,-1.0,0,-1.0,0)
343+
// var maxGain = Double.MinValue
344+
// var leftSamples = Long.MinValue
345+
// var rightSamples = Long.MinValue
342346
for (featureIndex <- 0 until numFeatures) {
343347
for (splitIndex <- 0 until numSplits - 1){
344-
val gain = gains(featureIndex)(splitIndex)
348+
val gainStats = gains(featureIndex)(splitIndex)
345349
//println("featureIndex = " + featureIndex + ", splitIndex = " + splitIndex + ", gain = " + gain)
346-
if(gain._1 > maxGain) {
347-
maxGain = gain._1
348-
leftSamples = gain._2
349-
rightSamples = gain._3
350+
if(gainStats.gain > bestGainStats.gain) {
351+
bestGainStats = gainStats
350352
bestFeatureIndex = featureIndex
351353
bestSplitIndex = splitIndex
352354
println("bestFeatureIndex = " + bestFeatureIndex + ", bestSplitIndex = " + bestSplitIndex
353-
+ ", maxGain = " + maxGain + ", leftSamples = " + leftSamples + ",rightSamples = " + rightSamples)
355+
+ ", gain stats = " + bestGainStats)
354356
}
355357
}
356358
}
357-
(bestFeatureIndex,bestSplitIndex,maxGain,leftSamples,rightSamples)
359+
(bestFeatureIndex,bestSplitIndex,bestGainStats)
358360
}
359361

360-
(splits(bestFeatureIndex)(bestSplitIndex),gain,leftCount,rightCount)
361-
//TODO: Return array of node stats with split and impurity information
362+
(splits(bestFeatureIndex)(bestSplitIndex),gainStats)
362363
}
363364

364365
//Calculate best splits for all nodes at a given level
365-
val bestSplits = new Array[(Split, Double, Long, Long)](numNodes)
366+
val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
366367
for (node <- 0 until numNodes){
367368
val shift = 2*node*numSplits*numFeatures
368369
val binsForNode = binAggregates.slice(shift,shift+2*numSplits*numFeatures)

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,11 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
7272
assert(bestSplits.length == 1)
7373
assert(0==bestSplits(0)._1.feature)
7474
assert(10==bestSplits(0)._1.threshold)
75-
assert(0==bestSplits(0)._2)
76-
assert(10==bestSplits(0)._3)
77-
assert(990==bestSplits(0)._4)
75+
assert(0==bestSplits(0)._2.gain)
76+
assert(10==bestSplits(0)._2.leftSamples)
77+
assert(0==bestSplits(0)._2.leftImpurity)
78+
assert(990==bestSplits(0)._2.rightSamples)
79+
assert(0==bestSplits(0)._2.rightImpurity)
7880
}
7981

8082
test("stump with fixed label 1 for Gini"){
@@ -93,9 +95,11 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
9395
assert(bestSplits.length == 1)
9496
assert(0==bestSplits(0)._1.feature)
9597
assert(10==bestSplits(0)._1.threshold)
96-
assert(0==bestSplits(0)._2)
97-
assert(10==bestSplits(0)._3)
98-
assert(990==bestSplits(0)._4)
98+
assert(0==bestSplits(0)._2.gain)
99+
assert(10==bestSplits(0)._2.leftSamples)
100+
assert(0==bestSplits(0)._2.leftImpurity)
101+
assert(990==bestSplits(0)._2.rightSamples)
102+
assert(0==bestSplits(0)._2.rightImpurity)
99103
}
100104

101105

@@ -115,9 +119,11 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
115119
assert(bestSplits.length == 1)
116120
assert(0==bestSplits(0)._1.feature)
117121
assert(10==bestSplits(0)._1.threshold)
118-
assert(0==bestSplits(0)._2)
119-
assert(10==bestSplits(0)._3)
120-
assert(990==bestSplits(0)._4)
122+
assert(0==bestSplits(0)._2.gain)
123+
assert(10==bestSplits(0)._2.leftSamples)
124+
assert(0==bestSplits(0)._2.leftImpurity)
125+
assert(990==bestSplits(0)._2.rightSamples)
126+
assert(0==bestSplits(0)._2.rightImpurity)
121127
}
122128

123129
test("stump with fixed label 1 for Entropy"){
@@ -136,9 +142,11 @@ class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
136142
assert(bestSplits.length == 1)
137143
assert(0==bestSplits(0)._1.feature)
138144
assert(10==bestSplits(0)._1.threshold)
139-
assert(0==bestSplits(0)._2)
140-
assert(10==bestSplits(0)._3)
141-
assert(990==bestSplits(0)._4)
145+
assert(0==bestSplits(0)._2.gain)
146+
assert(10==bestSplits(0)._2.leftSamples)
147+
assert(0==bestSplits(0)._2.leftImpurity)
148+
assert(990==bestSplits(0)._2.rightSamples)
149+
assert(0==bestSplits(0)._2.rightImpurity)
142150
}
143151

144152

0 commit comments

Comments
 (0)