Skip to content

Commit 79cdb9b

Browse files
qiping.lqpmengxr
authored andcommitted
[SPARK-2207][SPARK-3272][MLLib]Add minimum information gain and minimum instances per node as training parameters for decision tree.
These two parameters can act as early stop rules to do pre-pruning. When a split cause cause left or right child to have less than `minInstancesPerNode` or has less information gain than `minInfoGain`, current node will not be split by this split. When there is no possible splits that satisfy requirements, there is no useful information gain stats, but we still need to calculate the predict value for current node. So I separated calculation of predict from calculation of information gain, which can also save computation when the number of possible splits is large. Please see [SPARK-3272](https://issues.apache.org/jira/browse/SPARK-3272) for more details. CC: mengxr manishamde jkbradley, please help me review this, thanks. Author: qiping.lqp <[email protected]> Author: chouqin <[email protected]> Closes #2332 from chouqin/dt-preprune and squashes the following commits: f1d11d1 [chouqin] fix typo c7ebaf1 [chouqin] fix typo 39f9b60 [chouqin] change edge `minInstancesPerNode` to 2 and add one more test 0278a11 [chouqin] remove `noSplit` and set `Predict` private to tree d593ec7 [chouqin] fix docs and change minInstancesPerNode to 1 efcc736 [qiping.lqp] fix bug 10b8012 [qiping.lqp] fix style 6728fad [qiping.lqp] minor fix: remove empty lines bb465ca [qiping.lqp] Merge branch 'master' of https://github.com/apache/spark into dt-preprune cadd569 [qiping.lqp] add api docs 46b891f [qiping.lqp] fix bug e72c7e4 [qiping.lqp] add comments 845c6fa [qiping.lqp] fix style f195e83 [qiping.lqp] fix style 987cbf4 [qiping.lqp] fix bug ff34845 [qiping.lqp] separate calculation of predict of node from calculation of info gain ac42378 [qiping.lqp] add min info gain and min instances per node parameters in decision tree
1 parent 558962a commit 79cdb9b

File tree

7 files changed

+213
-36
lines changed

7 files changed

+213
-36
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
130130

131131
// Find best split for all nodes at a level.
132132
timer.start("findBestSplits")
133-
val splitsStatsForLevel: Array[(Split, InformationGainStats)] =
133+
val splitsStatsForLevel: Array[(Split, InformationGainStats, Predict)] =
134134
DecisionTree.findBestSplits(treeInput, parentImpurities,
135135
metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
136136
timer.stop("findBestSplits")
@@ -143,8 +143,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
143143
timer.start("extractNodeInfo")
144144
val split = nodeSplitStats._1
145145
val stats = nodeSplitStats._2
146+
val predict = nodeSplitStats._3.predict
146147
val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
147-
val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
148+
val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats))
148149
logDebug("Node = " + node)
149150
nodes(nodeIndex) = node
150151
timer.stop("extractNodeInfo")
@@ -425,7 +426,7 @@ object DecisionTree extends Serializable with Logging {
425426
splits: Array[Array[Split]],
426427
bins: Array[Array[Bin]],
427428
maxLevelForSingleGroup: Int,
428-
timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = {
429+
timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats, Predict)] = {
429430
// split into groups to avoid memory overflow during aggregation
430431
if (level > maxLevelForSingleGroup) {
431432
// When information for all nodes at a given level cannot be stored in memory,
@@ -434,7 +435,7 @@ object DecisionTree extends Serializable with Logging {
434435
// numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
435436
val numGroups = 1 << level - maxLevelForSingleGroup
436437
logDebug("numGroups = " + numGroups)
437-
var bestSplits = new Array[(Split, InformationGainStats)](0)
438+
var bestSplits = new Array[(Split, InformationGainStats, Predict)](0)
438439
// Iterate over each group of nodes at a level.
439440
var groupIndex = 0
440441
while (groupIndex < numGroups) {
@@ -605,7 +606,7 @@ object DecisionTree extends Serializable with Logging {
605606
bins: Array[Array[Bin]],
606607
timer: TimeTracker,
607608
numGroups: Int = 1,
608-
groupIndex: Int = 0): Array[(Split, InformationGainStats)] = {
609+
groupIndex: Int = 0): Array[(Split, InformationGainStats, Predict)] = {
609610

610611
/*
611612
* The high-level descriptions of the best split optimizations are noted here.
@@ -705,7 +706,7 @@ object DecisionTree extends Serializable with Logging {
705706

706707
// Calculate best splits for all nodes at a given level
707708
timer.start("chooseSplits")
708-
val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
709+
val bestSplits = new Array[(Split, InformationGainStats, Predict)](numNodes)
709710
// Iterating over all nodes at this level
710711
var nodeIndex = 0
711712
while (nodeIndex < numNodes) {
@@ -734,28 +735,27 @@ object DecisionTree extends Serializable with Logging {
734735
topImpurity: Double,
735736
level: Int,
736737
metadata: DecisionTreeMetadata): InformationGainStats = {
737-
738738
val leftCount = leftImpurityCalculator.count
739739
val rightCount = rightImpurityCalculator.count
740740

741-
val totalCount = leftCount + rightCount
742-
if (totalCount == 0) {
743-
// Return arbitrary prediction.
744-
return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
741+
// If left child or right child doesn't satisfy minimum instances per node,
742+
// then this split is invalid, return invalid information gain stats.
743+
if ((leftCount < metadata.minInstancesPerNode) ||
744+
(rightCount < metadata.minInstancesPerNode)) {
745+
return InformationGainStats.invalidInformationGainStats
745746
}
746747

747-
val parentNodeAgg = leftImpurityCalculator.copy
748-
parentNodeAgg.add(rightImpurityCalculator)
748+
val totalCount = leftCount + rightCount
749+
749750
// impurity of parent node
750751
val impurity = if (level > 0) {
751752
topImpurity
752753
} else {
754+
val parentNodeAgg = leftImpurityCalculator.copy
755+
parentNodeAgg.add(rightImpurityCalculator)
753756
parentNodeAgg.calculate()
754757
}
755758

756-
val predict = parentNodeAgg.predict
757-
val prob = parentNodeAgg.prob(predict)
758-
759759
val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
760760
val rightImpurity = rightImpurityCalculator.calculate()
761761

@@ -764,7 +764,31 @@ object DecisionTree extends Serializable with Logging {
764764

765765
val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
766766

767-
new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
767+
// if information gain doesn't satisfy minimum information gain,
768+
// then this split is invalid, return invalid information gain stats.
769+
if (gain < metadata.minInfoGain) {
770+
return InformationGainStats.invalidInformationGainStats
771+
}
772+
773+
new InformationGainStats(gain, impurity, leftImpurity, rightImpurity)
774+
}
775+
776+
/**
777+
* Calculate predict value for current node, given stats of any split.
778+
* Note that this function is called only once for each node.
779+
* @param leftImpurityCalculator left node aggregates for a split
780+
* @param rightImpurityCalculator right node aggregates for a node
781+
* @return predict value for current node
782+
*/
783+
private def calculatePredict(
784+
leftImpurityCalculator: ImpurityCalculator,
785+
rightImpurityCalculator: ImpurityCalculator): Predict = {
786+
val parentNodeAgg = leftImpurityCalculator.copy
787+
parentNodeAgg.add(rightImpurityCalculator)
788+
val predict = parentNodeAgg.predict
789+
val prob = parentNodeAgg.prob(predict)
790+
791+
new Predict(predict, prob)
768792
}
769793

770794
/**
@@ -780,12 +804,15 @@ object DecisionTree extends Serializable with Logging {
780804
nodeImpurity: Double,
781805
level: Int,
782806
metadata: DecisionTreeMetadata,
783-
splits: Array[Array[Split]]): (Split, InformationGainStats) = {
807+
splits: Array[Array[Split]]): (Split, InformationGainStats, Predict) = {
784808

785809
logDebug("node impurity = " + nodeImpurity)
786810

811+
// calculate predict only once
812+
var predict: Option[Predict] = None
813+
787814
// For each (feature, split), calculate the gain, and select the best (feature, split).
788-
Range(0, metadata.numFeatures).map { featureIndex =>
815+
val (bestSplit, bestSplitStats) = Range(0, metadata.numFeatures).map { featureIndex =>
789816
val numSplits = metadata.numSplits(featureIndex)
790817
if (metadata.isContinuous(featureIndex)) {
791818
// Cumulative sum (scanLeft) of bin statistics.
@@ -803,6 +830,7 @@ object DecisionTree extends Serializable with Logging {
803830
val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
804831
val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
805832
rightChildStats.subtract(leftChildStats)
833+
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
806834
val gainStats =
807835
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
808836
(splitIdx, gainStats)
@@ -816,6 +844,7 @@ object DecisionTree extends Serializable with Logging {
816844
Range(0, numSplits).map { splitIndex =>
817845
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
818846
val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
847+
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
819848
val gainStats =
820849
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
821850
(splitIndex, gainStats)
@@ -887,6 +916,7 @@ object DecisionTree extends Serializable with Logging {
887916
val rightChildStats =
888917
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
889918
rightChildStats.subtract(leftChildStats)
919+
predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
890920
val gainStats =
891921
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
892922
(splitIndex, gainStats)
@@ -898,6 +928,10 @@ object DecisionTree extends Serializable with Logging {
898928
(bestFeatureSplit, bestFeatureGainStats)
899929
}
900930
}.maxBy(_._2.gain)
931+
932+
require(predict.isDefined, "must calculate predict for each node")
933+
934+
(bestSplit, bestSplitStats, predict.get)
901935
}
902936

903937
/**

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,13 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
4949
* k) implies the feature n is categorical with k categories 0,
5050
* 1, 2, ... , k-1. It's important to note that features are
5151
* zero-indexed.
52+
* @param minInstancesPerNode Minimum number of instances each child must have after split.
53+
* Default value is 1. If a split cause left or right child
54+
* to have less than minInstancesPerNode,
55+
* this split will not be considered as a valid split.
56+
* @param minInfoGain Minimum information gain a split must get. Default value is 0.0.
57+
* If a split has less information gain than minInfoGain,
58+
* this split will not be considered as a valid split.
5259
* @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
5360
* 256 MB.
5461
*/
@@ -61,6 +68,8 @@ class Strategy (
6168
val maxBins: Int = 32,
6269
val quantileCalculationStrategy: QuantileStrategy = Sort,
6370
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
71+
val minInstancesPerNode: Int = 1,
72+
val minInfoGain: Double = 0.0,
6473
val maxMemoryInMB: Int = 256) extends Serializable {
6574

6675
if (algo == Classification) {

mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ private[tree] class DecisionTreeMetadata(
4545
val unorderedFeatures: Set[Int],
4646
val numBins: Array[Int],
4747
val impurity: Impurity,
48-
val quantileStrategy: QuantileStrategy) extends Serializable {
48+
val quantileStrategy: QuantileStrategy,
49+
val minInstancesPerNode: Int,
50+
val minInfoGain: Double) extends Serializable {
4951

5052
def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
5153

@@ -127,7 +129,8 @@ private[tree] object DecisionTreeMetadata {
127129

128130
new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
129131
strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
130-
strategy.impurity, strategy.quantileCalculationStrategy)
132+
strategy.impurity, strategy.quantileCalculationStrategy,
133+
strategy.minInstancesPerNode, strategy.minInfoGain)
131134
}
132135

133136
/**

mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,26 @@ import org.apache.spark.annotation.DeveloperApi
2626
* @param impurity current node impurity
2727
* @param leftImpurity left node impurity
2828
* @param rightImpurity right node impurity
29-
* @param predict predicted value
30-
* @param prob probability of the label (classification only)
3129
*/
3230
@DeveloperApi
3331
class InformationGainStats(
3432
val gain: Double,
3533
val impurity: Double,
3634
val leftImpurity: Double,
37-
val rightImpurity: Double,
38-
val predict: Double,
39-
val prob: Double = 0.0) extends Serializable {
35+
val rightImpurity: Double) extends Serializable {
4036

4137
override def toString = {
42-
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f"
43-
.format(gain, impurity, leftImpurity, rightImpurity, predict, prob)
38+
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
39+
.format(gain, impurity, leftImpurity, rightImpurity)
4440
}
4541
}
42+
43+
44+
private[tree] object InformationGainStats {
45+
/**
46+
* An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to
47+
* denote that current split doesn't satisfies minimum info gain or
48+
* minimum number of instances per node.
49+
*/
50+
val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0)
51+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.mllib.tree.model
19+
20+
import org.apache.spark.annotation.DeveloperApi
21+
22+
/**
23+
* :: DeveloperApi ::
24+
* Predicted value for a node
25+
* @param predict predicted value
26+
* @param prob probability of the label (classification only)
27+
*/
28+
@DeveloperApi
29+
private[tree] class Predict(
30+
val predict: Double,
31+
val prob: Double = 0.0) extends Serializable{
32+
33+
override def toString = {
34+
"predict = %f, prob = %f".format(predict, prob)
35+
}
36+
}

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree.model
1919

2020
import org.apache.spark.annotation.DeveloperApi
2121
import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
22+
import org.apache.spark.mllib.tree.configuration.FeatureType
23+
import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
2224

2325
/**
2426
* :: DeveloperApi ::

0 commit comments

Comments
 (0)