@@ -836,10 +836,11 @@ object DecisionTree extends Serializable with Logging {
836836 calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
837837 (splitIdx, gainStats)
838838 }.maxBy(_._2.gain)
839- if (bestFeatureGainStats == InformationGainStats .invalidInformationGainStats ) {
839+ if (bestFeatureGainStats.gain < metadata.minInfoGain ) {
840840 (Split .noSplit, InformationGainStats .invalidInformationGainStats)
841+ } else {
842+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
841843 }
842- (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
843844 } else if (metadata.isUnordered(featureIndex)) {
844845 // Unordered categorical feature
845846 val (leftChildOffset, rightChildOffset) =
@@ -855,8 +856,9 @@ object DecisionTree extends Serializable with Logging {
855856 }.maxBy(_._2.gain)
856857 if (bestFeatureGainStats == InformationGainStats .invalidInformationGainStats) {
857858 (Split .noSplit, InformationGainStats .invalidInformationGainStats)
859+ } else {
860+ (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
858861 }
859- (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
860862 } else {
861863 // Ordered categorical feature
862864 val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex)
@@ -930,12 +932,13 @@ object DecisionTree extends Serializable with Logging {
930932 }.maxBy(_._2.gain)
931933 if (bestFeatureGainStats == InformationGainStats .invalidInformationGainStats) {
932934 (Split .noSplit, InformationGainStats .invalidInformationGainStats)
935+ } else {
936+ val categoriesForSplit =
937+ categoriesSortedByCentroid.map(_._1.toDouble).slice(0 , bestFeatureSplitIndex + 1 )
938+ val bestFeatureSplit =
939+ new Split (featureIndex, Double .MinValue , Categorical , categoriesForSplit)
940+ (bestFeatureSplit, bestFeatureGainStats)
933941 }
934- val categoriesForSplit =
935- categoriesSortedByCentroid.map(_._1.toDouble).slice(0 , bestFeatureSplitIndex + 1 )
936- val bestFeatureSplit =
937- new Split (featureIndex, Double .MinValue , Categorical , categoriesForSplit)
938- (bestFeatureSplit, bestFeatureGainStats)
939942 }
940943 }.maxBy(_._2.gain)
941944
0 commit comments