Skip to content

Commit 06b1690

Browse files
committed
Fixed off-by-one error in bin-to-split conversion
1 parent 9cc3e31 commit 06b1690

File tree

2 files changed

+7
-10
lines changed

2 files changed

+7
-10
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1012,7 +1012,7 @@ object DecisionTree extends Serializable with Logging {
10121012
= binData(shift + numClasses * splitIndex + innerClassIndex) +
10131013
leftNodeAgg(featureIndex)(splitIndex - 1)(innerClassIndex)
10141014
rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(innerClassIndex) =
1015-
binData(shift + (numClasses * (numBins - 2 - splitIndex) + innerClassIndex)) +
1015+
binData(shift + (numClasses * (numBins - 1 - splitIndex) + innerClassIndex)) +
10161016
rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(innerClassIndex)
10171017
innerClassIndex += 1
10181018
}
@@ -1077,13 +1077,13 @@ object DecisionTree extends Serializable with Logging {
10771077
// calculating right node aggregate for a split as a sum of right node aggregate of a
10781078
// higher split and the right bin aggregate of a bin where the split is a low split
10791079
rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(0) =
1080-
binData(shift + (3 * (numBins - 2 - splitIndex))) +
1080+
binData(shift + (3 * (numBins - 1 - splitIndex))) +
10811081
rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(0)
10821082
rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(1) =
1083-
binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) +
1083+
binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) +
10841084
rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(1)
10851085
rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(2) =
1086-
binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) +
1086+
binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) +
10871087
rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(2)
10881088

10891089
splitIndex += 1

mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -412,9 +412,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
412412

413413
val stats = bestSplits(0)._2
414414
assert(stats.gain > 0)
415-
assert(stats.predict === 0)
416-
assert(stats.prob > 0.5)
417-
assert(stats.prob < 0.6)
415+
assert(stats.predict === 1)
416+
assert(stats.prob == 0.6)
418417
assert(stats.impurity > 0.2)
419418
}
420419

@@ -440,8 +439,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
440439

441440
val stats = bestSplits(0)._2
442441
assert(stats.gain > 0)
443-
assert(stats.predict > 0.4)
444-
assert(stats.predict < 0.5)
442+
assert(stats.predict == 0.6)
445443
assert(stats.impurity > 0.2)
446444
}
447445

@@ -657,7 +655,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
657655
val bestSplit = bestSplits(0)._1
658656
assert(bestSplit.feature === 0)
659657
assert(bestSplit.categories.length === 1)
660-
println(bestSplit)
661658
assert(bestSplit.categories.contains(1.0))
662659
assert(bestSplit.featureType === Categorical)
663660
}

0 commit comments

Comments
 (0)