Skip to content

Commit ac42378

Browse files
author
qiping.lqp
committed
add min info gain and min instances per node parameters in decision tree
1 parent 7db5339 commit ac42378

File tree

6 files changed

+77
-7
lines changed

6 files changed

+77
-7
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 19 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -738,12 +738,15 @@ object DecisionTree extends Serializable with Logging {
738738
val leftCount = leftImpurityCalculator.count
739739
val rightCount = rightImpurityCalculator.count
740740

741-
val totalCount = leftCount + rightCount
742-
if (totalCount == 0) {
743-
// Return arbitrary prediction.
744-
return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
741+
// If left child or right child doesn't satisfy minimum instances per node,
742+
// then this split is invalid, return invalid information gain stats
// NOTE(review): this guard takes the place of the removed `totalCount == 0`
// early return. With the default `minInstancesPerNode = 0` declared on Strategy,
// two empty children (both counts 0) pass this check, and the later
// `rightCount / totalCount.toDouble` would then divide by zero.
// TODO confirm whether a minimum of 1 should be enforced.
743+
if ((leftCount < metadata.minInstancesPerNode) ||
744+
(rightCount < metadata.minInstancesPerNode)) {
745+
return InformationGainStats.invalidInformationGainStats
745746
}
746747

748+
val totalCount = leftCount + rightCount
749+
747750
val parentNodeAgg = leftImpurityCalculator.copy
748751
parentNodeAgg.add(rightImpurityCalculator)
749752
// impurity of parent node
@@ -763,6 +766,9 @@ object DecisionTree extends Serializable with Logging {
763766
val rightWeight = rightCount / totalCount.toDouble
764767

765768
val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
769+
if (gain < metadata.minInfoGain) {
770+
return InformationGainStats.invalidInformationGainStats
771+
}
766772

767773
new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
768774
}
@@ -807,6 +813,9 @@ object DecisionTree extends Serializable with Logging {
807813
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
808814
(splitIdx, gainStats)
809815
}.maxBy(_._2.gain)
816+
// NOTE(review): this `if` has no `else` and is not the last expression of the
// branch, so the (Split.noSplit, invalid stats) tuple built inside it is
// discarded. The branch still evaluates to the `splits(featureIndex)(...)` tuple
// below, paired with the invalid stats. If the intent is to report
// `Split.noSplit` when no candidate split is valid, the tuple below needs to
// move into an `else` — confirm against the caller.
if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) {
817+
(Split.noSplit, InformationGainStats.invalidInformationGainStats)
818+
}
810819
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
811820
} else if (metadata.isUnordered(featureIndex)) {
812821
// Unordered categorical feature
@@ -820,6 +829,9 @@ object DecisionTree extends Serializable with Logging {
820829
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
821830
(splitIndex, gainStats)
822831
}.maxBy(_._2.gain)
832+
// NOTE(review): this `if` has no `else` and is not the last expression of the
// branch, so the (Split.noSplit, invalid stats) tuple built inside it is
// discarded. The branch still evaluates to the `splits(featureIndex)(...)` tuple
// below, paired with the invalid stats. If the intent is to report
// `Split.noSplit` when no candidate split is valid, the tuple below needs to
// move into an `else` — confirm against the caller.
if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) {
833+
(Split.noSplit, InformationGainStats.invalidInformationGainStats)
834+
}
823835
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
824836
} else {
825837
// Ordered categorical feature
@@ -891,6 +903,9 @@ object DecisionTree extends Serializable with Logging {
891903
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
892904
(splitIndex, gainStats)
893905
}.maxBy(_._2.gain)
906+
// NOTE(review): this `if` is used as a statement (a `val` definition follows
// it), so the (Split.noSplit, invalid stats) tuple built inside it is discarded
// and execution continues into the code below even when the best gain stats
// are the invalid sentinel. Confirm whether an early exit or an `else` around
// the remaining code was intended.
if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) {
907+
(Split.noSplit, InformationGainStats.invalidInformationGainStats)
908+
}
894909
val categoriesForSplit =
895910
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
896911
val bestFeatureSplit =

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -61,6 +61,8 @@ class Strategy (
6161
val maxBins: Int = 100,
6262
val quantileCalculationStrategy: QuantileStrategy = Sort,
6363
val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
64+
val minInstancesPerNode: Int = 0,
65+
val minInfoGain: Double = 0.0,
6466
val maxMemoryInMB: Int = 128) extends Serializable {
6567

6668
if (algo == Classification) {

mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,9 @@ private[tree] class DecisionTreeMetadata(
4545
val unorderedFeatures: Set[Int],
4646
val numBins: Array[Int],
4747
val impurity: Impurity,
48-
val quantileStrategy: QuantileStrategy) extends Serializable {
48+
val quantileStrategy: QuantileStrategy,
49+
val minInstancesPerNode: Int,
50+
val minInfoGain: Double) extends Serializable {
4951

5052
def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
5153

@@ -127,7 +129,8 @@ private[tree] object DecisionTreeMetadata {
127129

128130
new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
129131
strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
130-
strategy.impurity, strategy.quantileCalculationStrategy)
132+
strategy.impurity, strategy.quantileCalculationStrategy,
133+
strategy.minInstancesPerNode, strategy.minInfoGain)
131134
}
132135

133136
/**

mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -43,3 +43,8 @@ class InformationGainStats(
4343
.format(gain, impurity, leftImpurity, rightImpurity, predict, prob)
4444
}
4545
}
46+
47+
48+
// Companion holding the shared sentinel for rejected splits.
private[tree] object InformationGainStats {
49+
// Sentinel returned by the gain computation when a candidate split violates
// the min-instances-per-node or min-info-gain constraint. Its gain is
// Double.MinValue, so it never beats a valid candidate under
// `maxBy(_._2.gain)`; the impurity fields are set to -1.0.
// NOTE(review): call sites compare against this value with `==`; presumably
// that is reference equality because InformationGainStats is a plain class —
// verify that no other code path constructs an equivalent instance.
val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, 0.0)
50+
}

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree.model
1919

2020
import org.apache.spark.annotation.DeveloperApi
2121
import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
22+
import org.apache.spark.mllib.tree.configuration.FeatureType
23+
import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
2224

2325
/**
2426
* :: DeveloperApi ::
@@ -66,3 +68,7 @@ private[tree] class DummyHighSplit(feature: Int, featureType: FeatureType)
6668
private[tree] class DummyCategoricalSplit(feature: Int, featureType: FeatureType)
6769
extends Split(feature, Double.MaxValue, featureType, List())
6870

71+
72+
// Companion holding the placeholder split used when a node cannot be split.
private[tree] object Split {
73+
// Placeholder meaning "no valid split was found": feature index -1,
// Double.MinValue as the second argument (presumably the threshold — confirm
// against the Split constructor), continuous feature type, and an empty list
// as the last argument.
val noSplit = new Split(-1, Double.MinValue, FeatureType.Continuous, List())
74+
}

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 40 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
2828
import org.apache.spark.mllib.tree.configuration.Strategy
2929
import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
3030
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
31-
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
31+
import org.apache.spark.mllib.tree.model.{Split, DecisionTreeModel, Node}
3232
import org.apache.spark.mllib.util.LocalSparkContext
3333

3434

@@ -684,6 +684,45 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
684684
validateClassifier(model, arr, 0.6)
685685
}
686686

687+
// Trains on three points with minInstancesPerNode = 4, so no child of any
// split can hold enough instances; the tree must collapse to a single leaf.
test("split must satisfy min instances per node requirements") {
688+
// Three labeled points: two with label 0.0 and one with label 1.0.
val arr = new Array[LabeledPoint](3)
689+
arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
690+
arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
691+
arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))
692+
693+
val input = sc.parallelize(arr)
694+
// minInstancesPerNode (4) exceeds the total number of training points (3).
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
695+
numClassesForClassification = 2, minInstancesPerNode = 4)
696+
697+
val model = DecisionTree.train(input, strategy)
698+
// The root must be a leaf predicting 0.0 and carrying the noSplit placeholder.
assert(model.topNode.isLeaf)
699+
assert(model.topNode.predict == 0.0)
700+
assert(model.topNode.split.get == Split.noSplit)
701+
// Every training point is therefore predicted as 0.0.
val predicts = input.map(p => model.predict(p.features)).collect()
702+
predicts.foreach { predict =>
703+
assert(predict == 0.0)
704+
}
705+
}
706+
707+
// Trains on three points with minInfoGain = 1.0 and expects every candidate
// split to be rejected, so the tree must collapse to a single leaf.
test("split must satisfy min info gain requirements") {
708+
// Three labeled points: two with label 0.0 and one with label 1.0.
val arr = new Array[LabeledPoint](3)
709+
arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
710+
arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
711+
arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))
712+
713+
val input = sc.parallelize(arr)
714+
// Presumably no split on this data reaches a Gini gain of 1.0 — the
// assertions below rely on that.
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
715+
numClassesForClassification = 2, minInfoGain = 1.0)
716+
717+
val model = DecisionTree.train(input, strategy)
718+
// The root must be a leaf predicting 0.0 and carrying the noSplit placeholder.
assert(model.topNode.isLeaf)
719+
assert(model.topNode.predict == 0.0)
720+
assert(model.topNode.split.get == Split.noSplit)
721+
// Every training point is therefore predicted as 0.0.
val predicts = input.map(p => model.predict(p.features)).collect()
722+
predicts.foreach { predict =>
723+
assert(predict == 0.0)
724+
}
725+
}
687726
}
688727

689728
object DecisionTreeSuite {

0 commit comments

Comments
 (0)