
Commit ff34845

Author: qiping.lqp

Separate the calculation of a node's prediction from the calculation of information gain.
1 parent ac42378 commit ff34845

File tree

4 files changed: +82 -29 lines

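The net effect of the change: DecisionTree.findBestSplits now returns a (Split, InformationGainStats, Predict) triple per node instead of packing the predicted value and its probability into InformationGainStats. A condensed sketch of the new access pattern, pieced together from the diffs below (variable names are illustrative):

    val (split, stats, predict) = splitsStatsForLevel(nodeIndex)
    val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)  // gain/impurity stay in InformationGainStats
    val label  = predict.predict                                    // predicted value now lives in Predict
    val prob   = predict.prob                                       // probability of the label (classification only)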

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 33 additions & 14 deletions
@@ -130,7 +130,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
 
       // Find best split for all nodes at a level.
       timer.start("findBestSplits")
-      val splitsStatsForLevel: Array[(Split, InformationGainStats)] =
+      val splitsStatsForLevel: Array[(Split, InformationGainStats, Predict)] =
         DecisionTree.findBestSplits(treeInput, parentImpurities,
           metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
       timer.stop("findBestSplits")
@@ -143,8 +143,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
         timer.start("extractNodeInfo")
         val split = nodeSplitStats._1
         val stats = nodeSplitStats._2
+        val predict = nodeSplitStats._3
         val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
-        val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
+        val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats))
         logDebug("Node = " + node)
         nodes(nodeIndex) = node
         timer.stop("extractNodeInfo")
@@ -425,7 +426,7 @@ object DecisionTree extends Serializable with Logging {
       splits: Array[Array[Split]],
       bins: Array[Array[Bin]],
       maxLevelForSingleGroup: Int,
-      timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = {
+      timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats, Predict)] = {
     // split into groups to avoid memory overflow during aggregation
     if (level > maxLevelForSingleGroup) {
       // When information for all nodes at a given level cannot be stored in memory,
@@ -434,7 +435,7 @@ object DecisionTree extends Serializable with Logging {
       // numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
       val numGroups = 1 << level - maxLevelForSingleGroup
       logDebug("numGroups = " + numGroups)
-      var bestSplits = new Array[(Split, InformationGainStats)](0)
+      var bestSplits = new Array[(Split, InformationGainStats, Predict)](0)
       // Iterate over each group of nodes at a level.
       var groupIndex = 0
       while (groupIndex < numGroups) {
@@ -605,7 +606,7 @@ object DecisionTree extends Serializable with Logging {
       bins: Array[Array[Bin]],
       timer: TimeTracker,
       numGroups: Int = 1,
-      groupIndex: Int = 0): Array[(Split, InformationGainStats)] = {
+      groupIndex: Int = 0): Array[(Split, InformationGainStats, Predict)] = {
 
     /*
      * The high-level descriptions of the best split optimizations are noted here.
@@ -705,7 +706,7 @@ object DecisionTree extends Serializable with Logging {
 
     // Calculate best splits for all nodes at a given level
     timer.start("chooseSplits")
-    val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
+    val bestSplits = new Array[(Split, InformationGainStats, Predict)](numNodes)
     // Iterating over all nodes at this level
     var nodeIndex = 0
     while (nodeIndex < numNodes) {
@@ -747,18 +748,16 @@ object DecisionTree extends Serializable with Logging {
 
     val totalCount = leftCount + rightCount
 
-    val parentNodeAgg = leftImpurityCalculator.copy
-    parentNodeAgg.add(rightImpurityCalculator)
+
     // impurity of parent node
     val impurity = if (level > 0) {
       topImpurity
     } else {
+      val parentNodeAgg = leftImpurityCalculator.copy
+      parentNodeAgg.add(rightImpurityCalculator)
       parentNodeAgg.calculate()
     }
 
-    val predict = parentNodeAgg.predict
-    val prob = parentNodeAgg.prob(predict)
-
     val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
     val rightImpurity = rightImpurityCalculator.calculate()
 
@@ -770,7 +769,18 @@ object DecisionTree extends Serializable with Logging {
       return InformationGainStats.invalidInformationGainStats
     }
 
-    new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
+    new InformationGainStats(gain, impurity, leftImpurity, rightImpurity)
+  }
+
+  private def calculatePredict(
+      leftImpurityCalculator: ImpurityCalculator,
+      rightImpurityCalculator: ImpurityCalculator): Predict = {
+    val parentNodeAgg = leftImpurityCalculator.copy
+    parentNodeAgg.add(rightImpurityCalculator)
+    val predict = parentNodeAgg.predict
+    val prob = parentNodeAgg.prob(predict)
+
+    new Predict(predict, prob)
   }
 
   /**
@@ -786,12 +796,14 @@ object DecisionTree extends Serializable with Logging {
       nodeImpurity: Double,
       level: Int,
       metadata: DecisionTreeMetadata,
-      splits: Array[Array[Split]]): (Split, InformationGainStats) = {
+      splits: Array[Array[Split]]): (Split, InformationGainStats, Predict) = {
 
     logDebug("node impurity = " + nodeImpurity)
 
+    var predict: Option[Predict] = None
+
     // For each (feature, split), calculate the gain, and select the best (feature, split).
-    Range(0, metadata.numFeatures).map { featureIndex =>
+    val (bestSplit, bestSplitStats) = Range(0, metadata.numFeatures).map { featureIndex =>
       val numSplits = metadata.numSplits(featureIndex)
       if (metadata.isContinuous(featureIndex)) {
         // Cumulative sum (scanLeft) of bin statistics.
@@ -809,6 +821,7 @@ object DecisionTree extends Serializable with Logging {
           val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
           val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
           rightChildStats.subtract(leftChildStats)
+          predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
          val gainStats =
            calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
          (splitIdx, gainStats)
@@ -825,6 +838,7 @@ object DecisionTree extends Serializable with Logging {
         Range(0, numSplits).map { splitIndex =>
           val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
           val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+          predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
          val gainStats =
            calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
          (splitIndex, gainStats)
@@ -899,6 +913,7 @@ object DecisionTree extends Serializable with Logging {
           val rightChildStats =
             binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
           rightChildStats.subtract(leftChildStats)
+          predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
          val gainStats =
            calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
          (splitIndex, gainStats)
@@ -913,6 +928,10 @@ object DecisionTree extends Serializable with Logging {
         (bestFeatureSplit, bestFeatureGainStats)
       }
     }.maxBy(_._2.gain)
+
+    require(predict.isDefined, "must calculate predict for each node")
+
+    (bestSplit, bestSplitStats, predict.get)
   }
 
   /**
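The predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) lines above compute a node's prediction at most once: for every candidate split of a node, the left and right child aggregates add up to the same parent aggregate, so the first (left, right) pair encountered already determines the Predict. A condensed sketch of that caching pattern (the helper name cacheNodePredict is illustrative, not part of the commit):

    var predict: Option[Predict] = None

    def cacheNodePredict(left: ImpurityCalculator, right: ImpurityCalculator): Unit = {
      // Only the first call does any work; later calls reuse the cached value.
      predict = Some(predict.getOrElse(calculatePredict(left, right)))
    }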

mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala

Lines changed: 4 additions & 8 deletions
@@ -26,25 +26,21 @@ import org.apache.spark.annotation.DeveloperApi
  * @param impurity current node impurity
  * @param leftImpurity left node impurity
  * @param rightImpurity right node impurity
- * @param predict predicted value
- * @param prob probability of the label (classification only)
  */
 @DeveloperApi
 class InformationGainStats(
     val gain: Double,
     val impurity: Double,
     val leftImpurity: Double,
-    val rightImpurity: Double,
-    val predict: Double,
-    val prob: Double = 0.0) extends Serializable {
+    val rightImpurity: Double) extends Serializable {
 
   override def toString = {
-    "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f"
-      .format(gain, impurity, leftImpurity, rightImpurity, predict, prob)
+    "gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
+      .format(gain, impurity, leftImpurity, rightImpurity)
   }
 }
 
 
 private[tree] object InformationGainStats {
-  val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, 0.0)
+  val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0)
 }
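With predict and prob removed, InformationGainStats carries only gain and impurity figures. A minimal sketch of the slimmed constructor, with made-up numbers for illustration:

    val stats = new InformationGainStats(
      gain = 0.12,           // information gain of the chosen split
      impurity = 0.48,       // impurity of the current node
      leftImpurity = 0.30,   // impurity of the left child
      rightImpurity = 0.25)  // impurity of the right child

    // Sentinel for "no valid split" (private[tree], shown for illustration):
    val invalid = InformationGainStats.invalidInformationGainStats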
mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala (new file)

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * :: DeveloperApi ::
+ * Predicted value for a node
+ * @param predict predicted value
+ * @param prob probability of the label (classification only)
+ */
+@DeveloperApi
+class Predict(
+    val predict: Double,
+    val prob: Double = 0.0) extends Serializable {
+
+  override def toString = {
+    "predict = %f, prob = %f".format(predict, prob)
+  }
+}
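A quick usage sketch of the new model class (values are illustrative):

    import org.apache.spark.mllib.tree.model.Predict

    val p = new Predict(predict = 1.0, prob = 0.6)
    p.toString  // "predict = 1.000000, prob = 0.600000"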

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 9 additions & 7 deletions
@@ -280,9 +280,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(split.threshold === Double.MinValue)
 
     val stats = bestSplits(0)._2
+    val predict = bestSplits(0)._3
     assert(stats.gain > 0)
-    assert(stats.predict === 1)
-    assert(stats.prob === 0.6)
+    assert(predict.predict === 1)
+    assert(predict.prob === 0.6)
     assert(stats.impurity > 0.2)
   }
 
@@ -313,8 +314,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(split.threshold === Double.MinValue)
 
     val stats = bestSplits(0)._2
+    val predict = bestSplits(0)._3.predict
     assert(stats.gain > 0)
-    assert(stats.predict === 0.6)
+    assert(predict === 0.6)
     assert(stats.impurity > 0.2)
   }
 
@@ -392,7 +394,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.gain === 0)
     assert(bestSplits(0)._2.leftImpurity === 0)
     assert(bestSplits(0)._2.rightImpurity === 0)
-    assert(bestSplits(0)._2.predict === 1)
+    assert(bestSplits(0)._3.predict === 1)
   }
 
   test("Binary classification stump with fixed label 0 for Entropy") {
@@ -421,7 +423,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.gain === 0)
     assert(bestSplits(0)._2.leftImpurity === 0)
     assert(bestSplits(0)._2.rightImpurity === 0)
-    assert(bestSplits(0)._2.predict === 0)
+    assert(bestSplits(0)._3.predict === 0)
   }
 
   test("Binary classification stump with fixed label 1 for Entropy") {
@@ -450,7 +452,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.gain === 0)
     assert(bestSplits(0)._2.leftImpurity === 0)
     assert(bestSplits(0)._2.rightImpurity === 0)
-    assert(bestSplits(0)._2.predict === 1)
+    assert(bestSplits(0)._3.predict === 1)
   }
 
   test("Second level node building with vs. without groups") {
@@ -501,7 +503,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity)
      assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity)
      assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity)
-      assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict)
+      assert(bestSplits(i)._3.predict === bestSplitsWithGroups(i)._3.predict)
     }
   }
 