Skip to content

Commit 19b01af

Browse files
committed
Merge remote-tracking branch 'chouqin/dt-preprune' into chouqin-dt-preprune
2 parents c6e2dfc + f1d11d1 commit 19b01af

File tree

1 file changed

+30
-4
lines changed

1 file changed

+30
-4
lines changed

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 30 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -683,8 +683,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
683683
arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))
684684

685685
val input = sc.parallelize(arr)
686-
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
687-
numClassesForClassification = 2, minInstancesPerNode = 4)
686+
val strategy = new Strategy(algo = Classification, impurity = Gini,
687+
maxDepth = 2, numClassesForClassification = 2, minInstancesPerNode = 2)
688688

689689
val model = DecisionTree.train(input, strategy)
690690
assert(model.topNode.isLeaf)
@@ -701,11 +701,37 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
701701
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
702702
new Array[Node](0), splits, bins, 10)
703703

704-
assert(bestSplits.length === 1)
704+
assert(bestSplits.length == 1)
705705
val bestInfoStats = bestSplits(0)._2
706706
assert(bestInfoStats == InformationGainStats.invalidInformationGainStats)
707707
}
708708

709+
test("don't choose split that doesn't satisfy min instance per node requirements") {
  // A split that would leave a child with fewer than minInstancesPerNode instances is
  // invalid and must not be chosen, even when its information gain is large.
  //
  // Feature 0 separates the single label-1.0 point from the other three (a 3/1 partition),
  // which violates minInstancesPerNode = 2. Feature 1 partitions the points 2/2, so it is
  // the split this test expects to be selected.
  val arr = new Array[LabeledPoint](4)
  arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0))
  arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
  arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
  arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))

  val input = sc.parallelize(arr)
  val strategy = new Strategy(algo = Classification, impurity = Gini,
    maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2),
    numClassesForClassification = 2, minInstancesPerNode = 2)
  val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
  val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
  val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
  val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
    new Array[Node](0), splits, bins, 10)

  assert(bestSplits.length == 1)
  // Each entry is a (split, stats) pair: _1 is the chosen split, _2 its information gain stats.
  val bestSplit = bestSplits(0)._1
  // Must read _2 here; comparing the split (_1) against the invalid-stats sentinel would
  // make the final assertion vacuously true.
  val bestSplitStats = bestSplits(0)._2
  assert(bestSplit.feature == 1)
  assert(bestSplitStats != InformationGainStats.invalidInformationGainStats)
}
734+
709735
test("split must satisfy min info gain requirements") {
710736
val arr = new Array[LabeledPoint](3)
711737
arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
@@ -731,7 +757,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
731757
val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
732758
new Array[Node](0), splits, bins, 10)
733759

734-
assert(bestSplits.length === 1)
760+
assert(bestSplits.length == 1)
735761
val bestInfoStats = bestSplits(0)._2
736762
assert(bestInfoStats == InformationGainStats.invalidInformationGainStats)
737763
}

0 commit comments

Comments
 (0)