Skip to content

Commit 1dd2735

Browse files
committed
bin search logic for multiclass
1 parent f16a9bb commit 1dd2735

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,9 @@ object DecisionTree extends Serializable with Logging {
549549
* Sequential search helper method to find bin for categorical feature in multiclass
550550
* classification. Dummy value of 0 used since it is not used in future calculation
551551
*/
552-
def sequentialBinSearchForCategoricalFeatureInBinaryClassification(): Int = 0
552+
def sequentialBinSearchForCategoricalFeatureInBinaryClassification(): Int = {
553+
labeledPoint.features(featureIndex).toInt
554+
}
553555

554556
/**
555557
* Sequential search helper method to find bin for categorical feature.
@@ -662,7 +664,7 @@ object DecisionTree extends Serializable with Logging {
662664
label.toInt match {
663665
case n: Int =>
664666
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
665-
if (isFeatureContinuous && strategy.isMultiClassification) {
667+
if (!isFeatureContinuous && strategy.isMultiClassification) {
666668
// Find all matching bins and increment their values
667669
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
668670
val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1

0 commit comments

Comments
 (0)