@@ -683,8 +683,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
683683 arr(2 ) = new LabeledPoint (0.0 , Vectors .sparse(2 , Seq ((0 , 1.0 ))))
684684
685685 val input = sc.parallelize(arr)
686- val strategy = new Strategy (algo = Classification , impurity = Gini , maxDepth = 2 ,
687- numClassesForClassification = 2 , minInstancesPerNode = 4 )
686+ val strategy = new Strategy (algo = Classification , impurity = Gini ,
687+ maxDepth = 2 , numClassesForClassification = 2 , minInstancesPerNode = 2 )
688688
689689 val model = DecisionTree .train(input, strategy)
690690 assert(model.topNode.isLeaf)
@@ -701,11 +701,37 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
701701 val bestSplits = DecisionTree .findBestSplits(treeInput, new Array (8 ), metadata, 0 ,
702702 new Array [Node ](0 ), splits, bins, 10 )
703703
704- assert(bestSplits.length === 1 )
704+ assert(bestSplits.length == 1 )
705705 val bestInfoStats = bestSplits(0 )._2
706706 assert(bestInfoStats == InformationGainStats .invalidInformationGainStats)
707707 }
708708
709+ test(" don't choose split that doesn't satisfy min instance per node requirements" ) {
710+ // if a split doesn't satisfy min instances per node requirements,
711+ // this split is invalid, even though the information gain of split is large.
712+ val arr = new Array [LabeledPoint ](4 )
713+ arr(0 ) = new LabeledPoint (0.0 , Vectors .dense(0.0 , 1.0 ))
714+ arr(1 ) = new LabeledPoint (1.0 , Vectors .dense(1.0 , 1.0 ))
715+ arr(2 ) = new LabeledPoint (0.0 , Vectors .dense(0.0 , 0.0 ))
716+ arr(3 ) = new LabeledPoint (0.0 , Vectors .dense(0.0 , 0.0 ))
717+
718+ val input = sc.parallelize(arr)
719+ val strategy = new Strategy (algo = Classification , impurity = Gini ,
720+ maxBins = 2 , maxDepth = 2 , categoricalFeaturesInfo = Map (0 -> 2 , 1 -> 2 ),
721+ numClassesForClassification = 2 , minInstancesPerNode = 2 )
722+ val metadata = DecisionTreeMetadata .buildMetadata(input, strategy)
723+ val (splits, bins) = DecisionTree .findSplitsBins(input, metadata)
724+ val treeInput = TreePoint .convertToTreeRDD(input, bins, metadata)
725+ val bestSplits = DecisionTree .findBestSplits(treeInput, new Array (8 ), metadata, 0 ,
726+ new Array [Node ](0 ), splits, bins, 10 )
727+
728+ assert(bestSplits.length == 1 )
729+ val bestSplit = bestSplits(0 )._1
730+ val bestSplitStats = bestSplits(0 )._1
731+ assert(bestSplit.feature == 1 )
732+ assert(bestSplitStats != InformationGainStats .invalidInformationGainStats)
733+ }
734+
709735 test(" split must satisfy min info gain requirements" ) {
710736 val arr = new Array [LabeledPoint ](3 )
711737 arr(0 ) = new LabeledPoint (0.0 , Vectors .sparse(2 , Seq ((0 , 0.0 ))))
@@ -731,7 +757,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
731757 val bestSplits = DecisionTree .findBestSplits(treeInput, new Array (8 ), metadata, 0 ,
732758 new Array [Node ](0 ), splits, bins, 10 )
733759
734- assert(bestSplits.length === 1 )
760+ assert(bestSplits.length == 1 )
735761 val bestInfoStats = bestSplits(0 )._2
736762 assert(bestInfoStats == InformationGainStats .invalidInformationGainStats)
737763 }
0 commit comments