Skip to content

Commit 46b891f

Browse files
author
qiping.lqp
committed
fix bug
1 parent e72c7e4 commit 46b891f

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -836,10 +836,11 @@ object DecisionTree extends Serializable with Logging {
836836
calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
837837
(splitIdx, gainStats)
838838
}.maxBy(_._2.gain)
839-
if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) {
839+
if (bestFeatureGainStats.gain < metadata.minInfoGain) {
840840
(Split.noSplit, InformationGainStats.invalidInformationGainStats)
841+
} else {
842+
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
841843
}
842-
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
843844
} else if (metadata.isUnordered(featureIndex)) {
844845
// Unordered categorical feature
845846
val (leftChildOffset, rightChildOffset) =
@@ -855,8 +856,9 @@ object DecisionTree extends Serializable with Logging {
855856
}.maxBy(_._2.gain)
856857
if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) {
857858
(Split.noSplit, InformationGainStats.invalidInformationGainStats)
859+
} else {
860+
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
858861
}
859-
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
860862
} else {
861863
// Ordered categorical feature
862864
val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex)
@@ -930,12 +932,13 @@ object DecisionTree extends Serializable with Logging {
930932
}.maxBy(_._2.gain)
931933
if (bestFeatureGainStats == InformationGainStats.invalidInformationGainStats) {
932934
(Split.noSplit, InformationGainStats.invalidInformationGainStats)
935+
} else {
936+
val categoriesForSplit =
937+
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
938+
val bestFeatureSplit =
939+
new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
940+
(bestFeatureSplit, bestFeatureGainStats)
933941
}
934-
val categoriesForSplit =
935-
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
936-
val bestFeatureSplit =
937-
new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
938-
(bestFeatureSplit, bestFeatureGainStats)
939942
}
940943
}.maxBy(_._2.gain)
941944

0 commit comments

Comments
 (0)